Feat (ptq): for minifloat benchmark (#712)
Giuseppe5 authored Oct 4, 2023
1 parent 9d24ace commit b051309
Showing 4 changed files with 307 additions and 147 deletions.
14 changes: 14 additions & 0 deletions src/brevitas/quant/experimental/float.py
@@ -124,3 +124,17 @@ class Fp8e5m2ActPerChannelFloat2dMSE(Fp8e5m2Mixin, MSESymmetricScale, ScaledFloa
"""
scaling_per_output_channel = True
scaling_stats_permute_dims = (1, 0, 2, 3)


class Fp8e4m3WeightPerChannelFloatMSE(Fp8e4m3Mixin, MSESymmetricScale, ScaledFloatWeightBase):
"""
FP8 signed E4M3 weight quantizer with per-channel MSE-based scaling.
"""
scaling_per_output_channel = True


class Fp8e4m3WeightPerTensorFloatMSE(Fp8e4m3Mixin, MSESymmetricScale, ScaledFloatWeightBase):
"""
FP8 signed E4M3 weight quantizer with per-tensor MSE-based scaling.
"""
scaling_per_output_channel = False
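
For context (this snippet is not part of the commit), a minimal sketch of how one of the new quantizers could be attached to a layer through Brevitas' usual weight_quant injector keyword; the layer shape and input are arbitrary:

import torch
import brevitas.nn as qnn
from brevitas.quant.experimental.float import Fp8e4m3WeightPerChannelFloatMSE

# FP8 e4m3 weights with per-channel, MSE-calibrated scales (illustrative only)
conv = qnn.QuantConv2d(3, 16, kernel_size=3, weight_quant=Fp8e4m3WeightPerChannelFloatMSE)
out = conv(torch.randn(1, 3, 32, 32))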
@@ -2,10 +2,10 @@
# SPDX-License-Identifier: BSD-3-Clause

import argparse
from functools import partial
from itertools import product
import os
import random
from time import sleep
from types import SimpleNamespace

import numpy as np
@@ -38,54 +38,66 @@

config.IGNORE_MISSING_KEYS = True


def parse_type(v, default_type):
if v == 'None':
return None
else:
return default_type(v)


def parse_bool(v):
if isinstance(v, bool):
return v
if v.lower() in ('yes', 'true', 't', 'y'):
return True
elif v.lower() in ('no', 'false', 'f', 'n'):
return False
else:
raise argparse.ArgumentTypeError('Boolean value expected.')


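# Hashable dict wrapper so equivalent configurations can be deduplicated via unique() below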
class hashabledict(dict):

def __hash__(self):
return hash(tuple(sorted(self.items())))


def unique(sequence):
seen = set()
return [x for x in sequence if not (x in seen or seen.add(x))]


# Torchvision models with top1 accuracy
TORCHVISION_TOP1_MAP = {
'resnet18': 69.758,
'mobilenet_v2': 71.898,
'vit_b_32': 75.912,}

OPTIONS = {
'model_name': TORCHVISION_TOP1_MAP.keys(),
'target_backend': ['fx', 'layerwise', 'flexml'], # Target backend
'scale_factor_type': ['float', 'po2'], # Scale factor type
'weight_bit_width': [8, 4], # Weight Bit Width
'act_bit_width': [8, 4], # Act bit width
'bias_bit_width': [None, 32, 16], # Bias Bit-Width for Po2 scale
'weight_quant_granularity': ['per_tensor', 'per_channel'], # Scaling Per Output Channel
'act_quant_type': ['asym', 'sym'], # Act Quant Type
'weight_param_method': ['stats', 'mse'], # Weight Quant Type
'act_param_method': ['stats', 'mse'], # Act Param Method
'bias_corr': [True], # Bias Correction
'graph_eq_iterations': [0, 20], # Graph Equalization
'graph_eq_merge_bias': [False, True], # Merge bias for Graph Equalization
'act_equalization': ['fx', 'layerwise', None], # Perform Activation Equalization (Smoothquant)
'learned_round': [False, True], # Enable/Disable Learned Round
'gptq': [False, True], # Enable/Disable GPTQ
'gptq_act_order': [False, True], # Use act_order heuristics for GPTQ
'gpfq': [False, True], # Enable/Disable GPFQ
'gpfq_p': [0.25, 0.75], # GPFQ P
'act_quant_percentile': [99.9, 99.99, 99.999], # Activation Quantization Percentile
'uint_sym_act_for_unsigned_values': [True], # Whether to use unsigned act quant when possible
}

OPTIONS_DEFAULT = {
'target_backend': ['fx'], # Target backend
'scale_factor_type': ['float'], # Scale factor type
'model_name': list(TORCHVISION_TOP1_MAP.keys()),
'quant_format': ['int'], # Quantization type (INT vs Float)
'target_backend': ['layerwise'], # Target backend
'scale_factor_type': ['float_scale'], # Scale factor type
'weight_mantissa_bit_width': [4],
'weight_exponent_bit_width': [3],
'act_mantissa_bit_width': [4],
'act_exponent_bit_width': [3],
'weight_bit_width': [8], # Weight Bit Width
'act_bit_width': [8], # Act bit width
'bias_bit_width': [32], # Bias Bit-Width for Po2 scale
'weight_quant_granularity': ['per_channel'], # Scaling Per Output Channel
'act_quant_type': ['sym'], # Act Quant Type
'act_param_method': ['mse'], # Act Param Method
'weight_param_method': ['stats'], # Weight Quant Type
'act_param_method': ['stats'], # Act Param Method
'weight_param_method': ['mse'], # Weight Quant Type
'bias_corr': [True], # Bias Correction
'graph_eq_iterations': [20], # Graph Equalization
'graph_eq_merge_bias': [True], # Merge bias for Graph Equalization
'act_equalization': [None], # Perform Activation Equalization (Smoothquant)
'act_equalization': ['layerwise'], # Perform Activation Equalization (Smoothquant)
'learned_round': [False], # Enable/Disable Learned Round
'gptq': [True], # Enable/Disable GPTQ
'gpfq': [False], # Enable/Disable GPFQ
'gpfq_p': [0.25], # GPFQ P
'gpfq_p': [0.75], # GPFQ P
'gptq_act_order': [False], # Use act_order heuristics for GPTQ
'act_quant_percentile': [99.999], # Activation Quantization Percentile
'uint_sym_act_for_unsigned_values': [True], # Whether to use unsigned act quant when possible
@@ -108,8 +120,12 @@
parser.add_argument(
'--batch-size-validation', default=256, type=int, help='Minibatch size for validation')
parser.add_argument('--calibration-samples', default=1000, type=int, help='Calibration size')
parser.add_argument(
'--options-to-exclude', choices=OPTIONS.keys(), nargs="+", default=[], help='Options to exclude from the search space')
for option_name, option_value in OPTIONS_DEFAULT.items():
if isinstance(option_value[0], bool):
type_args = parse_bool
else:
type_args = partial(parse_type, default_type=type(option_value[0]))
parser.add_argument(f'--{option_name}', default=option_value, nargs="+", type=type_args)


def main():
@@ -118,11 +134,9 @@ def main():
np.random.seed(SEED)
torch.manual_seed(SEED)

for option in args.options_to_exclude:
OPTIONS[option] = OPTIONS_DEFAULT[option]

args.gpu = get_gpu_index(args.idx)
print("Iter {}, GPU {}".format(args.idx, args.gpu))

try:
ptq_torchvision_models(args)
except Exception as E:
@@ -131,42 +145,25 @@ def main():

def ptq_torchvision_models(args):
# Generate all possible combinations, including invalid ones
# Split stats and mse due to the act_quant_percentile value

if 'stats' in OPTIONS['act_param_method']:
percentile_options = OPTIONS.copy()
percentile_options['act_param_method'] = ['stats']
else:
percentile_options = None
options = {k: getattr(args, k) for k, _ in OPTIONS_DEFAULT.items()}

if 'mse' in OPTIONS['act_param_method']:
mse_options = OPTIONS.copy()
mse_options['act_param_method'] = ['mse']
mse_options['act_quant_percentile'] = [None]
else:
mse_options = None

# Combine MSE and Percentile combinations, if they are defined
combinations = []
if mse_options is not None:
combinations += list(product(*mse_options.values()))
if percentile_options is not None:
combinations += list(product(*percentile_options.values()))
# Combine the two sets of combinations
# Generate Namespace for each configuration
configs = [
SimpleNamespace(**{k: v
for k, v in zip(OPTIONS.keys(), combination)})
for combination in combinations]
# Define which configurations are not valid
configs = list(map(validate_config, configs))
# Drop invalid configurations
configs = list(config for config in configs if config.is_valid)

if args.idx > len(configs):
combinations = list(product(*options.values()))

configs = []
for combination in combinations:
config_namespace = SimpleNamespace(
**{k: v for k, v in zip(OPTIONS_DEFAULT.keys(), combination)})
config_namespace = validate_config(config_namespace)
if config_namespace.is_valid:
configs.append(hashabledict(**config_namespace.__dict__))

configs = unique(configs)

if args.idx > len(configs) - 1:
return

config_namespace = configs[args.idx]
config_namespace = SimpleNamespace(**configs[args.idx])
print(config_namespace)

fp_accuracy = TORCHVISION_TOP1_MAP[config_namespace.model_name]
@@ -221,8 +218,13 @@ def ptq_torchvision_models(args):
# Define the quantized model
quant_model = quantize_model(
model,
quant_format=config_namespace.quant_format,
backend=config_namespace.target_backend,
act_bit_width=config_namespace.act_bit_width,
weight_mantissa_bit_width=config_namespace.weight_mantissa_bit_width,
weight_exponent_bit_width=config_namespace.weight_exponent_bit_width,
act_mantissa_bit_width=config_namespace.act_mantissa_bit_width,
act_exponent_bit_width=config_namespace.act_exponent_bit_width,
weight_bit_width=config_namespace.weight_bit_width,
weight_param_method=config_namespace.weight_param_method,
act_param_method=config_namespace.act_param_method,
@@ -298,7 +300,7 @@ def validate_config(config_namespace):
# Flexml supports only per-tensor, power-of-two scale factors
if config_namespace.target_backend == 'flexml' and (
config_namespace.weight_quant_granularity == 'per_channel' or
config_namespace.scale_factor_type == 'float32'):
config_namespace.scale_factor_type == 'float_scale'):
is_valid = False
# Merge bias can be enabled only when graph equalization is enabled
if config_namespace.graph_eq_iterations == 0 and config_namespace.graph_eq_merge_bias:
@@ -311,15 +313,33 @@
if not config_namespace.gptq and config_namespace.gptq_act_order:
is_valid = False

# If GPFQ is disabled, we execute only one configuration for p==0.25
if not config_namespace.gpfq and config_namespace.gpfq_p == 0.75:
is_valid = False

if config_namespace.act_equalization == 'layerwise' and config_namespace.target_backend == 'fx':
is_valid = False
if config_namespace.act_bit_width < config_namespace.weight_bit_width:
is_valid = False

if config_namespace.act_param_method == 'mse':
config_namespace.act_quant_percentile = None

if not config_namespace.gpfq:
config_namespace.gpfq_p = None

if config_namespace.quant_format == 'int':
config_namespace.weight_mantissa_bit_width = None
config_namespace.weight_exponent_bit_width = None
config_namespace.act_mantissa_bit_width = None
config_namespace.act_exponent_bit_width = None

if config_namespace.quant_format == 'float':
config_namespace.act_quant_type = 'sym'
config_namespace.weight_quant_type = 'sym'

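# For minifloat formats, exponent + mantissa bits must fill the requested bit width minus the sign bit (e.g. e4m3: 4 + 3 = 8 - 1)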
if config_namespace.quant_format == 'float':
if config_namespace.weight_exponent_bit_width + config_namespace.weight_mantissa_bit_width != config_namespace.weight_bit_width - 1:
is_valid = False
if config_namespace.act_exponent_bit_width + config_namespace.act_mantissa_bit_width != config_namespace.act_bit_width - 1:
is_valid = False

config_namespace.is_valid = is_valid
return config_namespace

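For reference (separate from the diff), a self-contained sketch of the sweep logic above: the per-option value lists are expanded with itertools.product, each combination becomes a SimpleNamespace, invalid combinations are dropped by the validation step, and normalized duplicates are removed. The option names and the single validation rule below are a reduced, hypothetical subset of the real ones:

from itertools import product
from types import SimpleNamespace

options = {
    'quant_format': ['int', 'float'],
    'weight_bit_width': [8],
    'weight_exponent_bit_width': [4, 5],
    'weight_mantissa_bit_width': [3, 2],}

def validate(cfg):
    # Float formats must satisfy exponent + mantissa == bit width - 1 (sign bit)
    cfg.is_valid = not (
        cfg.quant_format == 'float' and
        cfg.weight_exponent_bit_width + cfg.weight_mantissa_bit_width != cfg.weight_bit_width - 1)
    if cfg.quant_format == 'int':
        # Minifloat fields are meaningless for integer quantization
        cfg.weight_exponent_bit_width = None
        cfg.weight_mantissa_bit_width = None
    return cfg

configs = []
for combo in product(*options.values()):
    cfg = validate(SimpleNamespace(**dict(zip(options.keys(), combo))))
    if cfg.is_valid:
        configs.append(tuple(sorted(cfg.__dict__.items())))

# Drop duplicates that only differed in the now-neutralized minifloat fields
seen = set()
unique_configs = [c for c in configs if not (c in seen or seen.add(c))]
print(len(unique_configs))  # 3: one int config, two valid exponent/mantissa splits for 8-bit float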
@@ -1 +1,25 @@
python ptq_benchmark_torchvision.py $1 --calibration-dir /scratch/datasets/imagenet_symlink/calibration --validation-dir /scratch/datasets/imagenet_symlink/val --options-to-exclude graph_eq_merge_bias graph_eq_iterations
python ptq_benchmark_torchvision.py $1 --calibration-dir /scratch/datasets/imagenet_symlink/calibration --validation-dir /scratch/datasets/imagenet_symlink/val \
--quant_format float \
--scale_factor_type float_scale \
--weight_bit_width 2 3 4 5 6 7 8 \
--act_bit_width 2 3 4 5 6 7 8 \
--weight_mantissa_bit_width 1 2 3 4 5 6 \
--weight_exponent_bit_width 1 2 3 4 5 6 \
--act_mantissa_bit_width 1 2 3 4 5 6 \
--act_exponent_bit_width 1 2 3 4 5 6 \
--bias_bit_width None \
--weight_quant_granularity per_channel per_tensor \
--act_quant_type sym \
--weight_param_method stats \
--act_param_method mse \
--bias_corr True \
--graph_eq_iterations 20 \
--graph_eq_merge_bias True \
--act_equalization layerwise \
--learned_round False \
--gptq False \
--gptq_act_order False \
--gpfq False \
--gpfq_p None \
--uint_sym_act_for_unsigned_values False \
--act_quant_percentile None \
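
The literal None/True/False tokens above are converted back into Python values by the parse_type and parse_bool helpers that the benchmark registers for every dynamically generated flag. A reduced, runnable illustration (the defaults dict here is a hypothetical subset of OPTIONS_DEFAULT; the helpers repeat the script's own logic for self-containment):

import argparse
from functools import partial

def parse_type(v, default_type):
    return None if v == 'None' else default_type(v)

def parse_bool(v):
    if isinstance(v, bool):
        return v
    if v.lower() in ('yes', 'true', 't', 'y'):
        return True
    if v.lower() in ('no', 'false', 'f', 'n'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')

defaults = {'bias_bit_width': [32], 'gptq': [True], 'gpfq_p': [0.75]}
parser = argparse.ArgumentParser()
for name, values in defaults.items():
    if isinstance(values[0], bool):
        type_args = parse_bool
    else:
        type_args = partial(parse_type, default_type=type(values[0]))
    parser.add_argument(f'--{name}', default=values, nargs='+', type=type_args)

args = parser.parse_args('--bias_bit_width None --gptq False --gpfq_p 0.25 0.75'.split())
print(args.bias_bit_width, args.gptq, args.gpfq_p)  # [None] [False] [0.25, 0.75]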
