initial support for minifloat
Giuseppe5 committed Sep 27, 2023
1 parent 3b7b9c7 commit 11caf5b
Showing 3 changed files with 179 additions and 116 deletions.
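
For context before the diff: "minifloat" here means a low-precision floating-point format whose exponent and mantissa widths are set by the new *_exponent_bit_width and *_mantissa_bit_width options introduced below. As a minimal sketch of what such a format represents, the snippet below enumerates the value grid of an 8-bit minifloat, assuming an IEEE-754-style layout (one sign bit, biased exponent, subnormals at exponent code zero, no inf/NaN codes); it is an illustration, not Brevitas's actual quantizer:

import itertools

def minifloat_values(exponent_bits=3, mantissa_bits=4):
    # IEEE-754-style bias; an assumption, the commit does not pin this down.
    bias = 2 ** (exponent_bits - 1) - 1
    values = set()
    for sign, exp_code, frac in itertools.product(
            (1, -1), range(2 ** exponent_bits), range(2 ** mantissa_bits)):
        if exp_code == 0:  # subnormal codes: no implicit leading one
            magnitude = 2 ** (1 - bias) * (frac / 2 ** mantissa_bits)
        else:  # normal codes: implicit leading one
            magnitude = 2 ** (exp_code - bias) * (1 + frac / 2 ** mantissa_bits)
        values.add(sign * magnitude)
    return sorted(values)

# With the defaults used below (3 exponent bits, 4 mantissa bits, plus sign),
# this gives 255 distinct values from -31.0 to 31.0, densest near zero.
print(len(minifloat_values()))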
@@ -5,7 +5,6 @@
from itertools import product
import os
import random
-from time import sleep
from types import SimpleNamespace

import numpy as np
@@ -38,38 +37,33 @@

config.IGNORE_MISSING_KEYS = True


+class hashabledict(dict):
+
+    def __hash__(self):
+        return hash(tuple(sorted(self.items())))
+
+
+def unique(sequence):
+    seen = set()
+    return [x for x in sequence if not (x in seen or seen.add(x))]


# Torchvision models with top1 accuracy
TORCHVISION_TOP1_MAP = {
    'resnet18': 69.758,
    'mobilenet_v2': 71.898,
    'vit_b_32': 75.912,}

-OPTIONS = {
-    'model_name': TORCHVISION_TOP1_MAP.keys(),
-    'target_backend': ['fx', 'layerwise', 'flexml'], # Target backend
-    'scale_factor_type': ['float', 'po2'], # Scale factor type
-    'weight_bit_width': [8, 4], # Weight Bit Width
-    'act_bit_width': [8, 4], # Act bit width
-    'bias_bit_width': [32, 16], # Bias Bit-Width for Po2 scale
-    'weight_quant_granularity': ['per_tensor', 'per_channel'], # Scaling Per Output Channel
-    'act_quant_type': ['asym', 'sym'], # Act Quant Type
-    'weight_param_method': ['stats', 'mse'], # Weight Quant Type
-    'act_param_method': ['stats', 'mse'], # Act Param Method
-    'bias_corr': [True], # Bias Correction
-    'graph_eq_iterations': [0, 20], # Graph Equalization
-    'graph_eq_merge_bias': [False, True], # Merge bias for Graph Equalization
-    'act_equalization': ['fx', 'layerwise', None], # Perform Activation Equalization (Smoothquant)
-    'learned_round': [False, True], # Enable/Disable Learned Round
-    'gptq': [False, True], # Enable/Disable GPTQ
-    'gptq_act_order': [False, True], # Use act_order euristics for GPTQ
-    'gpfq': [False, True], # Enable/Disable GPFQ
-    'gpfq_p': [0.25, 0.75], # GPFQ P
-    'act_quant_percentile': [99.9, 99.99, 99.999], # Activation Quantization Percentile
-}

OPTIONS_DEFAULT = {
    'model_name': list(TORCHVISION_TOP1_MAP.keys()),
+    'quant_format': ['int'], # Quantization type (INT vs Float)
    'target_backend': ['fx'], # Target backend
-    'scale_factor_type': ['float'], # Scale factor type
+    'scale_factor_type': ['float_scale'], # Scale factor type
+    'weight_mantissa_bit_width': [4],
+    'weight_exponent_bit_width': [3],
+    'act_mantissa_bit_width': [4],
+    'act_exponent_bit_width': [3],
    'weight_bit_width': [8], # Weight Bit Width
    'act_bit_width': [8], # Act bit width
    'bias_bit_width': [32], # Bias Bit-Width for Po2 scale
@@ -84,7 +78,7 @@
    'learned_round': [False], # Enable/Disable Learned Round
    'gptq': [True], # Enable/Disable GPTQ
    'gpfq': [False], # Enable/Disable GPFQ
-    'gpfq_p': [0.25], # GPFQ P
+    'gpfq_p': [0.75], # GPFQ P
    'gptq_act_order': [False], # Use act_order heuristics for GPTQ
    'act_quant_percentile': [99.999], # Activation Quantization Percentile
}
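
A consistency note on the defaults above: one sign bit plus the 3-bit exponent and 4-bit mantissa accounts for the default 8-bit weight_bit_width (1 + 3 + 4 = 8), and likewise for activations; that the sign bit is counted inside the total width is an assumption, since the diff does not spell the convention out.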
@@ -106,8 +100,9 @@
parser.add_argument(
    '--batch-size-validation', default=256, type=int, help='Minibatch size for validation')
parser.add_argument('--calibration-samples', default=1000, type=int, help='Calibration size')
-parser.add_argument(
-    '--options-to-exclude', choices=OPTIONS.keys(), nargs="+", default=[], help='Calibration size')
+for option_name, option_value in OPTIONS_DEFAULT.items():
+    parser.add_argument(
+        f'--{option_name}', default=option_value, nargs="+", type=type(option_value[0]))
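
The loop above replaces the old --options-to-exclude mechanism: every key of OPTIONS_DEFAULT becomes a CLI flag, and because nargs="+" and type=type(option_value[0]) are used, each flag accepts one or more values parsed with the type of the first default entry. A self-contained sketch of the same pattern, reduced to two illustrative options:

import argparse

DEFAULTS = {'weight_bit_width': [8], 'act_quant_type': ['sym']}

parser = argparse.ArgumentParser()
for name, default in DEFAULTS.items():
    # One flag per option; extra values widen the sweep axis.
    parser.add_argument(f'--{name}', default=default, nargs="+", type=type(default[0]))

args = parser.parse_args(['--weight_bit_width', '8', '4'])
print(args.weight_bit_width)  # [8, 4] -- ints, since type(default[0]) is int
print(args.act_quant_type)  # ['sym'] -- the untouched default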


def main():
@@ -116,11 +111,9 @@ def main():
    np.random.seed(SEED)
    torch.manual_seed(SEED)

-    for option in args.options_to_exclude:
-        OPTIONS[option] = OPTIONS_DEFAULT[option]

    args.gpu = get_gpu_index(args.idx)
    print("Iter {}, GPU {}".format(args.idx, args.gpu))

    try:
        ptq_torchvision_models(args)
    except Exception as E:
@@ -129,43 +122,25 @@ def main():

def ptq_torchvision_models(args):
    # Generate all possible combinations, including invalid ones
-    # Split stats and mse due to the act_quant_percentile value

-    if 'stats' in OPTIONS['act_param_method']:
-        percentile_options = OPTIONS.copy()
-        percentile_options['act_param_method'] = ['stats']
-    else:
-        percentile_options = None
+    options = {k: getattr(args, k) for k, _ in OPTIONS_DEFAULT.items()}

-    if 'mse' in OPTIONS['act_param_method']:
-        mse_options = OPTIONS.copy()
-        mse_options['act_param_method'] = ['mse']
-        mse_options['act_quant_percentile'] = [None]
-    else:
-        mse_options = None

-    # Combine MSE and Percentile combinations, if they are defined
-    combinations = []
-    if mse_options is not None:
-        combinations += list(product(*mse_options.values()))
-    if percentile_options is not None:
-        combinations += list(product(*percentile_options.values()))
-    # Combine the two sets of combinations
-    # Generate Namespace for each configuration
-    configs = [
-        SimpleNamespace(**{k: v
-                           for k, v in zip(OPTIONS.keys(), combination)})
-        for combination in combinations]
-    # Define which configurations are not valid
-    configs = list(map(validate_config, configs))
-    # Drop invalid configurations
-    configs = list(config for config in configs if config.is_valid)
+    combinations = list(product(*options.values()))

+    configs = []
+    for combination in combinations:
+        config_namespace = SimpleNamespace(
+            **{k: v for k, v in zip(OPTIONS_DEFAULT.keys(), combination)})
+        config_namespace = validate_config(config_namespace)
+        if config_namespace.is_valid:
+            configs.append(hashabledict(**config_namespace.__dict__))

+    configs = unique(configs)

    if args.idx >= len(configs):
        return

-    config_namespace = configs[args.idx]
-    print(config_namespace)
+    config_namespace = SimpleNamespace(**configs[args.idx])

    fp_accuracy = TORCHVISION_TOP1_MAP[config_namespace.model_name]
    # Get model-specific configurations about input shapes and normalization
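
Why the new hashabledict and unique helpers matter here: validate_config canonicalizes fields that are inert for a given configuration (see the changes to it below), so distinct raw combinations can collapse into identical dicts, and making the dicts hashable lets unique() drop those duplicates while preserving order. A self-contained sketch of the effect, with the helpers restated and illustrative field values:

class hashabledict(dict):

    def __hash__(self):
        return hash(tuple(sorted(self.items())))


def unique(sequence):
    seen = set()
    return [x for x in sequence if not (x in seen or seen.add(x))]


# gpfq_p is inert when gpfq is False; once validate_config nulls it, the
# first two configurations are equal and only one survives deduplication.
a = hashabledict(gpfq=False, gpfq_p=None)
b = hashabledict(gpfq=False, gpfq_p=None)
c = hashabledict(gpfq=True, gpfq_p=0.75)
print(unique([a, b, c]))  # two entries remain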
@@ -219,6 +194,7 @@ def ptq_torchvision_models(args):
    # Define the quantized model
    quant_model = quantize_model(
        model,
+        quant_format=config_namespace.quant_format,
        backend=config_namespace.target_backend,
        act_bit_width=config_namespace.act_bit_width,
        weight_bit_width=config_namespace.weight_bit_width,
@@ -295,7 +271,7 @@ def validate_config(config_namespace):
    # Flexml supports only per-tensor, power-of-two scale factors
    if config_namespace.target_backend == 'flexml' and (
            config_namespace.weight_quant_granularity == 'per_channel' or
-            config_namespace.scale_factor_type == 'float32'):
+            config_namespace.scale_factor_type == 'float_scale'):
        is_valid = False
    # Merge bias can be enabled only when graph equalization is enabled
    if config_namespace.graph_eq_iterations == 0 and config_namespace.graph_eq_merge_bias:
@@ -308,15 +284,27 @@
    if not config_namespace.gptq and config_namespace.gptq_act_order:
        is_valid = False

-    # If GPFQ is disabled, we execute only one configuration for p==0.25
-    if not config_namespace.gpfq and config_namespace.gpfq_p == 0.75:
-        is_valid = False

    if config_namespace.act_equalization == 'layerwise' and config_namespace.target_backend == 'fx':
        is_valid = False
    if config_namespace.act_bit_width < config_namespace.weight_bit_width:
        is_valid = False

+    if config_namespace.act_param_method == 'mse':
+        config_namespace.act_quant_percentile = None
+
+    if not config_namespace.gpfq:
+        config_namespace.gpfq_p = None
+
+    if config_namespace.quant_format == 'int':
+        config_namespace.weight_mantissa_bit_width = None
+        config_namespace.weight_exponent_bit_width = None
+        config_namespace.act_mantissa_bit_width = None
+        config_namespace.act_exponent_bit_width = None
+
+    if config_namespace.quant_format == 'float':
+        config_namespace.act_quant_type = 'sym'
+        config_namespace.weight_quant_type = 'sym'

    config_namespace.is_valid = is_valid
    return config_namespace
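
Taken together, validate_config now plays two roles: it rejects impossible combinations via is_valid, and it canonicalizes fields that are inert under the chosen quant_format, which is what makes the deduplication earlier in the file effective. A miniature sketch of the same filter-then-canonicalize pattern (validate_config_sketch and its field subset are illustrative, not the commit's function):

from types import SimpleNamespace

def validate_config_sketch(cfg):
    # Filter: reject impossible combinations.
    cfg.is_valid = not (cfg.act_bit_width < cfg.weight_bit_width)
    # Canonicalize: null out minifloat fields that integer formats ignore.
    if cfg.quant_format == 'int':
        cfg.weight_exponent_bit_width = None
        cfg.weight_mantissa_bit_width = None
    return cfg

cfg = SimpleNamespace(
    quant_format='int', act_bit_width=8, weight_bit_width=8,
    weight_exponent_bit_width=3, weight_mantissa_bit_width=4)
cfg = validate_config_sketch(cfg)
print(cfg.is_valid, cfg.weight_exponent_bit_width)  # True None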

@@ -1 +1,4 @@
-python ptq_benchmark_torchvision.py $1 --calibration-dir /scratch/datasets/imagenet_symlink/calibration --validation-dir /scratch/datasets/imagenet_symlink/val --options-to-exclude graph_eq_merge_bias graph_eq_iterations
+python ptq_benchmark_torchvision.py $1 --calibration-dir /scratch/datasets/imagenet_symlink/calibration --validation-dir /scratch/datasets/imagenet_symlink/val \
+    --graph_eq_iterations 50 \
+    --act_param_method stats mse \
+    --act_quant_percentile 99.9 99.99
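
With the per-option flags in place, the driver script now shapes the sweep explicitly instead of pruning OPTIONS via --options-to-exclude: graph equalization is pinned to 50 iterations, act_param_method covers both stats and mse, and two activation-percentile values are swept.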
