
Fix for bitwidth, support MSE
Giuseppe5 committed Oct 4, 2023
1 parent 601e70a commit e982285
Showing 3 changed files with 54 additions and 16 deletions.
14 changes: 14 additions & 0 deletions src/brevitas/quant/experimental/float.py
@@ -124,3 +124,17 @@ class Fp8e5m2ActPerChannelFloat2dMSE(Fp8e5m2Mixin, MSESymmetricScale, ScaledFloa
    """
    scaling_per_output_channel = True
    scaling_stats_permute_dims = (1, 0, 2, 3)


+class Fp8e4m3WeightPerChannelFloatMSE(Fp8e4m3Mixin, MSESymmetricScale, ScaledFloatWeightBase):
+    """
+    FP8 signed E4M3 weight quantizer with per-channel MSE-based scaling.
+    """
+    scaling_per_output_channel = True


+class Fp8e4m3WeightPerTensorFloatMSE(Fp8e4m3Mixin, MSESymmetricScale, ScaledFloatWeightBase):
+    """
+    FP8 signed E4M3 weight quantizer with per-tensor MSE-based scaling.
+    """
+    scaling_per_output_channel = False
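
The two new quantizers drop in anywhere the existing float weight quantizers are accepted. A minimal sketch of attaching one to a layer (the layer and input shapes here are illustrative, not part of this commit):

    import torch
    import brevitas.nn as qnn
    from brevitas.quant.experimental.float import Fp8e4m3WeightPerChannelFloatMSE

    # FP8 E4M3 weights; the per-channel scale is found by MSE search
    # (via MSESymmetricScale) rather than taken directly from absmax stats.
    conv = qnn.QuantConv2d(3, 16, kernel_size=3, weight_quant=Fp8e4m3WeightPerChannelFloatMSE)
    out = conv(torch.randn(1, 3, 32, 32))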
27 changes: 21 additions & 6 deletions src/brevitas_examples/imagenet_classification/ptq/ptq_benchmark_torchvision.py
@@ -38,6 +38,17 @@
config.IGNORE_MISSING_KEYS = True


+def str2bool(v):
+    if isinstance(v, bool):
+        return v
+    if v.lower() in ('yes', 'true', 't', 'y', '1'):
+        return True
+    elif v.lower() in ('no', 'false', 'f', 'n', '0'):
+        return False
+    else:
+        raise argparse.ArgumentTypeError('Boolean value expected.')
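
argparse's type=bool is a known footgun: bool('false') is True, so any non-empty string parses as True. Routing boolean options through the str2bool helper above avoids that. A quick usage sketch (the flag name is just an example, mirroring the options table below):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--bias-corr', type=str2bool, nargs='+', default=[True])
    args = parser.parse_args(['--bias-corr', 'false'])
    assert args.bias_corr == [False]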


class hashabledict(dict):

    def __hash__(self):
@@ -66,11 +77,11 @@ def unique(sequence):
    'act_exponent_bit_width': [3],
    'weight_bit_width': [8],  # Weight Bit Width
    'act_bit_width': [8],  # Act bit width
-    'bias_bit_width': [32],  # Bias Bit-Width for Po2 scale
+    'bias_bit_width': [None],  # Bias Bit-Width for Po2 scale
    'weight_quant_granularity': ['per_channel'],  # Scaling Per Output Channel
    'act_quant_type': ['sym'],  # Act Quant Type
-    'act_param_method': ['mse'],  # Act Param Method
-    'weight_param_method': ['stats'],  # Weight Quant Type
+    'act_param_method': ['stats'],  # Act Param Method
+    'weight_param_method': ['mse'],  # Weight Param Method
    'bias_corr': [True],  # Bias Correction
    'graph_eq_iterations': [20],  # Graph Equalization
    'graph_eq_merge_bias': [True],  # Merge bias for Graph Equalization
@@ -102,8 +113,12 @@ def unique(sequence):
    '--batch-size-validation', default=256, type=int, help='Minibatch size for validation')
parser.add_argument('--calibration-samples', default=1000, type=int, help='Calibration size')
for option_name, option_value in OPTIONS_DEFAULT.items():
-    parser.add_argument(
-        f'--{option_name}', default=option_value, nargs="+", type=type(option_value[0]))
+    if isinstance(option_value[0], bool):
+        type_args = str2bool
+    else:
+        type_args = type(option_value[0])
+    parser.add_argument(f'--{option_name}', default=option_value, nargs="+", type=type_args)
+print(parser)


def main():
@@ -138,7 +153,7 @@ def ptq_torchvision_models(args):

    configs = unique(configs)

-    if args.idx > len(configs):
+    if args.idx > len(configs) - 1:
        return

    config_namespace = SimpleNamespace(**configs[args.idx])
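The guard change above is an off-by-one fix: valid indices into configs run from 0 to len(configs) - 1, so the old args.idx > len(configs) still admitted args.idx == len(configs) and crashed at configs[args.idx]. A standalone illustration with made-up values:

    configs = ['cfg0', 'cfg1', 'cfg2']
    idx = 3
    # Old guard: 3 > 3 is False, so configs[3] would raise IndexError.
    # New guard: 3 > 2 is True, so the out-of-range index is skipped.
    if idx > len(configs) - 1:
        print('index out of range, skipping')
    else:
        print(configs[idx])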
29 changes: 19 additions & 10 deletions src/brevitas_examples/imagenet_classification/ptq/ptq_common.py
@@ -22,8 +22,11 @@
import brevitas.nn as qnn
from brevitas.quant.experimental.float import Fp8e4m3Act
from brevitas.quant.experimental.float import Fp8e4m3ActPerTensorFloat
+from brevitas.quant.experimental.float import Fp8e4m3ActPerTensorFloatMSE
from brevitas.quant.experimental.float import Fp8e4m3WeightPerChannelFloat
+from brevitas.quant.experimental.float import Fp8e4m3WeightPerChannelFloatMSE
from brevitas.quant.experimental.float import Fp8e4m3WeightPerTensorFloat
+from brevitas.quant.experimental.float import Fp8e4m3WeightPerTensorFloatMSE
from brevitas.quant.fixed_point import Int8ActPerTensorFixedPoint
from brevitas.quant.fixed_point import Int8ActPerTensorFixedPointMSE
from brevitas.quant.fixed_point import Int8WeightPerChannelFixedPoint
@@ -85,7 +88,12 @@
            'per_tensor': {
                'sym': Fp8e4m3WeightPerTensorFloat},
            'per_channel': {
-                'sym': Fp8e4m3WeightPerChannelFloat}}}}}
+                'sym': Fp8e4m3WeightPerChannelFloat}},
+        'mse': {
+            'per_tensor': {
+                'sym': Fp8e4m3WeightPerTensorFloatMSE},
+            'per_channel': {
+                'sym': Fp8e4m3WeightPerChannelFloatMSE}}}}}

INPUT_QUANT_MAP = {
'int': {
@@ -108,7 +116,10 @@
        'float_scale': {
            'stats': {
                'per_tensor': {
-                    'sym': Fp8e4m3ActPerTensorFloat},}}},}
+                    'sym': Fp8e4m3ActPerTensorFloat}},
+            'mse': {
+                'per_tensor': {
+                    'sym': Fp8e4m3ActPerTensorFloatMSE},}}}}
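
Taken together, the two 'mse' branches make the MSE float quantizers reachable through the same nested lookups the existing 'stats' entries use. A sketch of the resolution path, assuming the nesting order format -> scale type -> param method -> granularity -> quant type implied by the surrounding branches, and that the module is importable as brevitas_examples.imagenet_classification.ptq.ptq_common:

    from brevitas_examples.imagenet_classification.ptq.ptq_common import INPUT_QUANT_MAP
    from brevitas_examples.imagenet_classification.ptq.ptq_common import WEIGHT_QUANT_MAP

    # Both lookups land on the quantizer classes imported in this commit.
    weight_quant = WEIGHT_QUANT_MAP['float']['float_scale']['mse']['per_channel']['sym']
    act_quant = INPUT_QUANT_MAP['float']['float_scale']['mse']['per_tensor']['sym']
    print(weight_quant.__name__, act_quant.__name__)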


def quantize_model(
@@ -184,15 +195,15 @@ def layerwise_bit_width_fn_weight(module):
    # Missing fix for name_shadowing for all variables
    weight_bit_width_dict = {}
    act_bit_width_dict = {}
-    if weight_quant_format == 'int' and backend == 'layerwise':
+    if quant_format == 'int' and backend == 'layerwise':
        weight_bit_width_dict['weight_bit_width'] = layerwise_bit_width_fn_weight
        act_bit_width_dict['act_bit_width'] = layerwise_bit_width_fn_act

    else:
        weight_bit_width_dict['weight_bit_width'] = weight_bit_width
        act_bit_width_dict['act_bit_width'] = act_bit_width

-    if weight_quant_format == 'float' and backend == 'layerwise':
+    if quant_format == 'float' and backend == 'layerwise':
        weight_bit_width_dict['weight_bit_width'] = layerwise_bit_width_fn_weight
        act_bit_width_dict['act_bit_width'] = layerwise_bit_width_fn_act
        weight_bit_width_dict['weight_mantissa_bit_width'] = layerwise_bit_width_fn_weight_mantissa
@@ -271,15 +282,13 @@ def kwargs_prefix(prefix, weight_kwargs):

    weight_bit_width_dict = {'bit_width': weight_bit_width}
    if weight_quant_format == 'float':
-        weight_bit_width_dict = {
-            'exponent_bit_width': weight_exponent_bit_width,
-            'mantissa_bit_width': weight_mantissa_bit_width}
+        weight_bit_width_dict['exponent_bit_width'] = weight_exponent_bit_width
+        weight_bit_width_dict['mantissa_bit_width'] = weight_mantissa_bit_width

    act_bit_width_dict = {'bit_width': act_bit_width}
    if act_quant_format == 'float':
-        act_bit_width_dict = {
-            'exponent_bit_width': act_exponent_bit_width,
-            'mantissa_bit_width': act_mantissa_bit_width}
+        act_bit_width_dict['exponent_bit_width'] = act_exponent_bit_width
+        act_bit_width_dict['mantissa_bit_width'] = act_mantissa_bit_width

    # Retrieve base input, weight, and bias quantizers
    bias_quant = BIAS_BIT_WIDTH_MAP[bias_bit_width]
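
The dict changes above are the bit-width fix named in the commit title: the old code rebuilt each dict from scratch for float formats, silently dropping the 'bit_width' key, while the new code adds the exponent and mantissa widths alongside it. A minimal standalone illustration:

    bit_width_dict = {'bit_width': 8}

    # Old behaviour: reassignment replaces the dict, losing 'bit_width'.
    replaced = {'exponent_bit_width': 4, 'mantissa_bit_width': 3}
    assert 'bit_width' not in replaced

    # New behaviour: in-place updates keep all three keys together.
    bit_width_dict['exponent_bit_width'] = 4
    bit_width_dict['mantissa_bit_width'] = 3
    assert set(bit_width_dict) == {'bit_width', 'exponent_bit_width', 'mantissa_bit_width'}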
