Skip to content

Commit

Permalink
Feat (examples/ptq): support for dynamic act quant (#935)
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 authored May 14, 2024
1 parent a1926f0 commit 3464ec7
Show file tree
Hide file tree
Showing 5 changed files with 80 additions and 30 deletions.
14 changes: 10 additions & 4 deletions src/brevitas/quant_tensor/torch_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,13 @@

import functools
import math
from typing import Callable
import warnings

import torch
from torch import Tensor
import torch.nn.functional as F

import brevitas
from brevitas.function.ops import max_int
from brevitas.function.ops_ste import ceil_ste
from brevitas.utils.torch_utils import compute_channel_view_shape
Expand Down Expand Up @@ -358,11 +359,16 @@ def create_quant_tensor(tensor, scale, bit_width, zero_point, signed, training):
training=training)


def quant_output_scale_impl(fn, inp, quant_input_scale, quant_weight_scale):
def quant_output_scale_impl(
fn: Callable, inp: Tensor, quant_input_scale: Tensor, quant_weight_scale: Tensor):
channel_dim = -1 if fn == F.linear else 1
output_scale_shape = compute_channel_view_shape(inp, channel_dim=channel_dim)
output_scale = quant_weight_scale.view(output_scale_shape)
output_scale = output_scale * quant_input_scale.view(output_scale_shape)

quant_weight_scale = quant_weight_scale.view(output_scale_shape)
if len(quant_input_scale.shape) == 0:
quant_input_scale = quant_input_scale.view(output_scale_shape)

output_scale = quant_weight_scale * quant_input_scale
return output_scale


Expand Down
1 change: 1 addition & 0 deletions src/brevitas_examples/common/generative/quant_blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ def forward(self, x) -> Tensor:
shape = x.shape
x = self.scaling_stats_input_view_shape_impl(x)
x = self.stats_impl(x)

x = self.dynamic_scaling_broadcastable_fn(x, shape)
return x

Expand Down
4 changes: 4 additions & 0 deletions src/brevitas_examples/imagenet_classification/ptq/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ usage: ptq_evaluate.py [-h] --calibration-dir CALIBRATION_DIR --validation-dir
[--weight-quant-calibration-type {stats,mse}]
[--act-equalization {fx,layerwise,None}]
[--act-quant-calibration-type {stats,mse}]
[--act-scale-computation-type {static,dynamic}]
[--graph-eq-iterations GRAPH_EQ_ITERATIONS]
[--learned-round-iters LEARNED_ROUND_ITERS]
[--learned-round-lr LEARNED_ROUND_LR]
Expand Down Expand Up @@ -184,6 +185,9 @@ options:
--act-quant-calibration-type {stats,mse}
Activation quantization calibration type (default:
stats)
--act-scale-computation-type {static,dynamic}
Activation quantization scale computation type
(default: static)
--graph-eq-iterations GRAPH_EQ_ITERATIONS
Numbers of iterations for graph equalization (default:
20)
Expand Down
78 changes: 55 additions & 23 deletions src/brevitas_examples/imagenet_classification/ptq/ptq_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import torch.backends.cudnn as cudnn
from tqdm import tqdm

from brevitas.core.function_wrapper.shape import OverBatchOverTensorView
from brevitas.core.scaling.standalone import ParameterFromStatsFromParameterScaling
from brevitas.core.zero_point import ParameterFromStatsFromParameterZeroPoint
from brevitas.graph.calibrate import bias_correction_mode
Expand Down Expand Up @@ -49,10 +50,28 @@
from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerChannelFloatMSE
from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerTensorFloat
from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerTensorFloatMSE
from brevitas_examples.common.generative.quantizers import Int8DynamicActPerTensorFloat
from brevitas_examples.common.generative.quantizers import ShiftedUint8DynamicActPerTensorFloat
from brevitas_examples.imagenet_classification.ptq.learned_round_utils import learned_round_iterator
from brevitas_examples.imagenet_classification.ptq.learned_round_utils import save_inp_out_data
from brevitas_examples.imagenet_classification.ptq.learned_round_utils import split_layers


# Every element of the Batch will have its own scale factor and zero point
class CNNShiftedUint8DynamicActPerTensorFloat(ShiftedUint8DynamicActPerTensorFloat):
scaling_stats_input_view_shape_impl = OverBatchOverTensorView
scaling_stats_permute_dims = None
stats_reduce_dim = 1
dynamic_scaling_broadcastable_fn = lambda x, shape: x.view(shape[0], *[1 for _ in range(len(shape[1:]))])


class CNNInt8DynamicActPerTensorFloat(Int8DynamicActPerTensorFloat):
scaling_stats_input_view_shape_impl = OverBatchOverTensorView
scaling_stats_permute_dims = None
stats_reduce_dim = 1
dynamic_scaling_broadcastable_fn = lambda x, shape: x.view(shape[0], *[1 for _ in range(len(shape[1:]))])


QUANTIZE_MAP = {'layerwise': layerwise_quantize, 'fx': quantize, 'flexml': quantize_flexml}

BIAS_BIT_WIDTH_MAP = {32: Int32Bias, 16: Int16Bias, None: None}
Expand Down Expand Up @@ -98,21 +117,29 @@

INPUT_QUANT_MAP = {
'int': {
'float_scale': {
'stats': {
'per_tensor': {
'sym': Int8ActPerTensorFloat, 'asym': ShiftedUint8ActPerTensorFloat}},
'mse': {
'per_tensor': {
'sym': Int8ActPerTensorFloatMSE, 'asym': ShiftedUint8ActPerTensorFloatMSE}}},
'po2_scale': {
'stats': {
'per_tensor': {
'sym': Int8ActPerTensorFixedPoint, 'asym': ShiftedUint8ActPerTensorFixedPoint},
},
'mse': {
'per_tensor': {
'sym': Int8ActPerTensorFixedPointMSE}},}},
'static': {
'float_scale': {
'stats': {
'per_tensor': {
'sym': Int8ActPerTensorFloat, 'asym': ShiftedUint8ActPerTensorFloat}},
'mse': {
'per_tensor': {
'sym': Int8ActPerTensorFloatMSE,
'asym': ShiftedUint8ActPerTensorFloatMSE}}},
'po2_scale': {
'stats': {
'per_tensor': {
'sym': Int8ActPerTensorFixedPoint,
'asym': ShiftedUint8ActPerTensorFixedPoint},},
'mse': {
'per_tensor': {
'sym': Int8ActPerTensorFixedPointMSE}}}},
'dynamic': {
'float_scale': {
'stats': {
'per_tensor': {
'sym': CNNInt8DynamicActPerTensorFloat,
'asym': CNNShiftedUint8DynamicActPerTensorFloat}}}}},
'float': {
'float_scale': {
'stats': {
Expand Down Expand Up @@ -146,6 +173,7 @@ def quantize_model(
act_param_method='stats',
weight_quant_type='sym',
act_quant_granularity='per_tensor',
act_scale_computation_type='dynamic',
uint_sym_act_for_unsigned_values=True,
dtype=torch.float32,
device='cpu'):
Expand All @@ -165,8 +193,10 @@ def quantize_model(
weight_mantissa_bit_width,
weight_exponent_bit_width,
act_mantissa_bit_width,
act_exponent_bit_width,
)
act_exponent_bit_width)

if act_scale_computation_type == 'dynamic':
assert bias_bit_width is None, "Bias quantization is not supported with dynamic activation quantization"

weight_quant_format = quant_format
act_quant_format = quant_format
Expand Down Expand Up @@ -253,6 +283,7 @@ def layerwise_bit_width_fn_weight(module):
act_quant_type=act_quant_type,
act_quant_granularity=act_quant_granularity,
act_quant_percentile=act_quant_percentile,
act_scale_computation_type=act_scale_computation_type,
**weight_bit_width_dict,
**act_bit_width_dict)

Expand Down Expand Up @@ -288,6 +319,7 @@ def create_quant_maps(
act_exponent_bit_width=None,
act_bit_width=None,
act_scale_type=None,
act_scale_computation_type=None,
act_param_method=None,
act_quant_type=None,
act_quant_granularity=None,
Expand Down Expand Up @@ -317,14 +349,14 @@ def kwargs_prefix(prefix, weight_kwargs):
weight_quant = weight_quant.let(**weight_bit_width_dict)

if act_bit_width is not None:
act_quant = INPUT_QUANT_MAP[act_quant_format][act_scale_type][act_param_method][
act_quant_granularity][act_quant_type]
act_quant = INPUT_QUANT_MAP[act_quant_format][act_scale_computation_type][act_scale_type][
act_param_method][act_quant_granularity][act_quant_type]
# Some activations in MHA should always be symmetric
sym_act_quant = INPUT_QUANT_MAP[act_quant_format][act_scale_type][act_param_method][
act_quant_granularity]['sym']
sym_act_quant = INPUT_QUANT_MAP[act_quant_format][act_scale_computation_type][
act_scale_type][act_param_method][act_quant_granularity]['sym']
# Linear layers with 2d input should always be per tensor
per_tensor_act_quant = INPUT_QUANT_MAP[act_quant_format][act_scale_type][act_param_method][
'per_tensor'][act_quant_type]
per_tensor_act_quant = INPUT_QUANT_MAP[act_quant_format][act_scale_computation_type][
act_scale_type][act_param_method]['per_tensor'][act_quant_type]
act_quant = act_quant.let(**act_bit_width_dict)
sym_act_quant = sym_act_quant.let(**act_bit_width_dict)
per_tensor_act_quant = per_tensor_act_quant.let(**act_bit_width_dict)
Expand Down
13 changes: 10 additions & 3 deletions src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,11 @@ def parse_type(v, default_type):
default='stats',
choices=['stats', 'mse'],
help='Activation quantization calibration type (default: stats)')
parser.add_argument(
'--act-scale-computation-type',
default='static',
choices=['static', 'dynamic'],
help='Activation quantization scale computation type (default: static)')
parser.add_argument(
'--graph-eq-iterations',
default=20,
Expand Down Expand Up @@ -411,11 +416,13 @@ def main():
weight_exponent_bit_width=args.weight_exponent_bit_width,
act_mantissa_bit_width=args.act_mantissa_bit_width,
act_exponent_bit_width=args.act_exponent_bit_width,
act_scale_computation_type=args.act_scale_computation_type,
uint_sym_act_for_unsigned_values=args.uint_sym_act_for_unsigned_values)

# Calibrate the quant_model on the calibration dataloader
print("Starting activation calibration:")
calibrate(calib_loader, quant_model)
if args.act_scale_computation_type == 'static':
# Calibrate the quant_model on the calibration dataloader
print("Starting activation calibration:")
calibrate(calib_loader, quant_model)

if args.gpfq:
print("Performing GPFQ:")
Expand Down

0 comments on commit 3464ec7

Please sign in to comment.