Feat (ptq): Adding A2Q Upper Bound clipping to GPFQ #734

Merged · 2 commits · Dec 21, 2023
23 changes: 3 additions & 20 deletions src/brevitas/core/scaling/pre_scaling.py
@@ -14,6 +14,7 @@
from brevitas.core.stats import SCALAR_SHAPE
from brevitas.core.stats.stats_wrapper import _Stats
from brevitas.function import abs_binary_sign_grad
from brevitas.function import get_upper_bound_on_l1_norm

__all__ = ["ParameterPreScalingWeightNorm", "AccumulatorAwareParameterPreScaling"]

@@ -170,33 +171,15 @@ def __init__(
)
self.accumulator_bit_width = accumulator_bit_width_impl

@brevitas.jit.script_method
def get_upper_bound_on_l1_norm(self, input_bit_width: Tensor, input_is_signed: bool) -> Tensor:
"""Calculate the upper bound on the l1-norm of the weights using the derivations from
`Quantized Neural Networks for Low-Precision Accumulation with Guaranteed Overflow Avoidance`
by I.Colbert, A.Pappalardo, and J.Petri-Koenig."""
assert input_bit_width is not None, "A2Q relies on input bit-width."
assert input_is_signed is not None, "A2Q relies on input sign."
input_is_signed = float(input_is_signed) # 1. if signed else 0.
# This is the minimum of the two maximum magnitudes that P could take, which are -2^{P-1}
# and 2^{P-1}-1. Note that evaluating to -2^{P-1} would mean there is a possibility of overflow
# on the positive side of this range.
max_accumulator_bit_width = self.accumulator_bit_width() # P
max_accumulator_mag = pow(2., max_accumulator_bit_width - 1.) - 1. # 2^{P-1}-1
# This is the maximum possible magnitude that the input data could take. When the data is signed,
# this is 2^{N-1}. When the data is unsigned, this is 2^N - 1. We use a slightly looser bound here
# to simplify our derivations on the export validation.
max_input_mag_inverse = pow(2., input_is_signed - input_bit_width)
return max_accumulator_mag * max_input_mag_inverse

@brevitas.jit.script_method
def forward(self, weights: Tensor, input_bit_width: Tensor, input_is_signed: bool) -> Tensor:
"""Takes weights as input and returns the pre-clipping scaling factor"""
weights = self.stats_input_view_shape_impl(weights)
d_w = self.stats(weights) # denominator for weight normalization
s = self.scaling_impl(weights) # s
g = abs_binary_sign_grad(self.restrict_clamp_scaling(self.value)) # g
T = self.get_upper_bound_on_l1_norm(input_bit_width, input_is_signed) # T / s
T = get_upper_bound_on_l1_norm(
self.accumulator_bit_width(), input_bit_width, input_is_signed) # T / s
g = torch.clamp_max(g / s, T)
value = d_w / g # calculating final pre-clipping scaling factor
# re-apply clamp_min_ste from restrict_scaling_impl to the specified pre_scaling_min_val
15 changes: 15 additions & 0 deletions src/brevitas/function/ops.py
@@ -201,3 +201,18 @@ def max_float(exponent_bit_width: Tensor, mantissa_bit_width: Tensor, exponent_b
device=mantissa_bit_width.device)))
max_val = max_mantissa * (2 ** max_exponent)
return max_val


def get_upper_bound_on_l1_norm(
accumulator_bit_width: Tensor, input_bit_width: Tensor, input_is_signed: bool) -> Tensor:
"""Calculate the upper bound on the l1-norm of the weights using the derivations from
`Quantized Neural Networks for Low-Precision Accumulation with Guaranteed Overflow Avoidance`
by I.Colbert, A.Pappalardo, and J.Petri-Koenig."""
assert input_bit_width is not None, "A2Q relies on input bit-width."
assert input_is_signed is not None, "A2Q relies on input sign."
assert accumulator_bit_width is not None, "A2Q relies on accumulator bit-width."
input_is_signed = float(input_is_signed) # 1. if signed else 0.
max_accumulator_bit_width = accumulator_bit_width # P
max_accumulator_mag = pow(2., max_accumulator_bit_width - 1.) - 1. # 2^{P-1}-1
max_input_mag_inverse = pow(2., input_is_signed - input_bit_width)
return max_accumulator_mag * max_input_mag_inverse
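
For reference, a minimal sketch of the new helper in isolation, assuming a 16-bit accumulator and an unsigned 8-bit input; the numbers are illustrative and not part of this diff:

```python
import torch

from brevitas.function import get_upper_bound_on_l1_norm

# Illustrative values: P = 16 (accumulator bit width), N = 8, unsigned input.
acc_bit_width = torch.tensor(16.)
input_bit_width = torch.tensor(8.)
input_is_signed = False

# T = (2^{P-1} - 1) * 2^{(1 if signed else 0) - N} = 32767 / 256
T = get_upper_bound_on_l1_norm(acc_bit_width, input_bit_width, input_is_signed)
print(T)  # ~= tensor(127.9961)
```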
142 changes: 126 additions & 16 deletions src/brevitas/graph/gpfq.py
@@ -8,6 +8,7 @@
import torch
import unfoldNd

from brevitas.function import get_upper_bound_on_l1_norm
from brevitas.graph.gpxq import GPxQ
from brevitas.graph.gpxq import gpxq_mode
from brevitas.graph.gpxq import StopFwdException
@@ -45,7 +46,9 @@ def __init__(
use_quant_activations: bool = True,
p: float = 1.0,
return_forward_output: bool = False,
act_order: bool = False) -> None:
act_order: bool = False,
use_gpfa2q: bool = False,
accumulator_bit_width: Optional[int] = None) -> None:
if not inplace:
model = deepcopy(model)
super().__init__(
@@ -61,6 +64,10 @@ def __init__(
self.model.forward = self.catch_stopfwd
self.p = p

# GPFA2Q params
self.use_gpfa2q = use_gpfa2q
self.accumulator_bit_width = accumulator_bit_width

def catch_stopfwd(self, *args, **kwargs):
# Collect quant input
try:
@@ -96,28 +103,31 @@ def catch_stopfwd(self, *args, **kwargs):

def initialize_module_optimizer(
self, layer, name, act_order, len_parallel_layers, create_weight_orig):
return GPFQ(
layer=layer,
name=name,
act_order=act_order,
len_parallel_layers=len_parallel_layers,
create_weight_orig=create_weight_orig,
p=self.p)
if not self.use_gpfa2q:
return GPFQ(
layer=layer,
name=name,
act_order=act_order,
len_parallel_layers=len_parallel_layers,
create_weight_orig=create_weight_orig,
p=self.p)
else:
return GPFA2Q(
layer=layer,
name=name,
act_order=act_order,
len_parallel_layers=len_parallel_layers,
create_weight_orig=create_weight_orig,
p=self.p,
accumulator_bit_width=self.accumulator_bit_width)


class GPFQ(GPxQ):
"""
Based on https://github.com/YixuanSeanZhou/Quantized_Neural_Nets/tree/main
"""

def __init__(
self,
layer,
name,
act_order,
len_parallel_layers=1,
create_weight_orig=True,
p=1.0) -> None:
def __init__(self, layer, name, act_order, len_parallel_layers, create_weight_orig, p) -> None:

super().__init__(layer, name, act_order, len_parallel_layers, create_weight_orig)

@@ -256,3 +266,103 @@ def single_layer_update(self):

del self.float_input
del self.quantized_input


class GPFA2Q(GPFQ):

def __init__(
self,
layer,
name,
act_order,
len_parallel_layers,
create_weight_orig,
accumulator_bit_width,
p) -> None:
GPFQ.__init__(
self,
layer=layer,
name=name,
act_order=act_order,
len_parallel_layers=len_parallel_layers,
create_weight_orig=create_weight_orig,
p=p)
self.accumulator_bit_width = accumulator_bit_width
assert self.accumulator_bit_width is not None

def single_layer_update(self):
# raise an error if there is no quant input to derive the upper bound from
if self.quant_input is None:
raise ValueError(
'Expected quant input to calculate Upper Bound on L1 norm, but received None')
weight = self.layer.weight.data
dev = weight.device
dtype = weight.dtype
if isinstance(self.layer, SUPPORTED_CONV_OP):
if isinstance(self.layer, (qnn.QuantConvTranspose1d, qnn.QuantConvTranspose2d)):
weight = weight.transpose(1, 0) # This performs a view
weight = weight.flatten(1)
weight = weight.view(self.groups, -1, weight.shape[-1]) # [Groups, OC/Groups, IC]
U = torch.zeros(
weight.shape[0], weight.shape[1], self.float_input.shape[1], device=dev, dtype=dtype)
self.float_input = self.float_input.to(dev)
self.quantized_input = self.quantized_input.to(dev)

# get upper bound
input_bit_width = self.quant_input.bit_width
input_is_signed = self.quant_input.signed
T = get_upper_bound_on_l1_norm(
torch.tensor(self.accumulator_bit_width), input_bit_width, input_is_signed)
s = self.layer.quant_weight_scale()
s = s.view(self.groups, -1) # [Groups, OC/Groups]

l1_norm = torch.zeros(weight.shape[:-1], device=dev)

# We don't need full Hessian, we just need the diagonal
self.H_diag = self.quantized_input.transpose(2, 1).square().sum(
2) # summing over Batch dimension
permutation_list = []
for group_index in range(self.groups):
if self.act_order:
# Re-order Hessian_diagonal so that weights associated to
# higher magnitude activations are quantized first
perm = torch.argsort(self.H_diag[group_index, :], descending=True)
else:
# No permutation, permutation tensor is an ordered index
perm = torch.tensor(range(weight.shape[-1]), device=dev)
permutation_list.append(perm)

for t in range(weight.shape[-1]):
for group_index in range(self.groups):
U[group_index] += torch.matmul(
weight[group_index, :, permutation_list[group_index][t]].unsqueeze(1),
self.float_input[group_index, :, permutation_list[group_index][t]].unsqueeze(
0)) #[OC/Groups, 1] * [1, INSHAPE[1]]
norm = torch.linalg.norm(
self.quantized_input[group_index, :, permutation_list[group_index][t]], 2) ** 2
if norm > 0:
q_arg = U[group_index].matmul(
self.quantized_input[group_index, :,
permutation_list[group_index][t]]) / norm
else:
q_arg = torch.zeros_like(U[group_index, :, 0])

weight[group_index, :, permutation_list[group_index][t]] = q_arg
q = self.get_quant_weights(t, 0, permutation_list)

for group_index in range(self.groups):
candidate_l1 = l1_norm[group_index] + torch.abs(q[group_index])
candidate_l1_mask = candidate_l1 > T * s[group_index]
if torch.any(candidate_l1_mask):
# set all values to 0 that are exceeding T * s
weight[group_index, :, permutation_list[group_index][t]][candidate_l1_mask] = 0
q[group_index][candidate_l1_mask] = 0
else:
l1_norm[group_index] = candidate_l1
U[group_index] -= torch.matmul(
q[group_index].unsqueeze(1),
self.quantized_input[group_index, :,
permutation_list[group_index][t]].unsqueeze(0))

del self.float_input
del self.quantized_input
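
For context, a rough usage sketch of the new flags, following the per-layer calibration loop used elsewhere in the Brevitas PTQ examples; `model` and `calib_loader` are hypothetical placeholders, and the loop helpers (`num_layers`, `update()`) come from the existing GPxQ workflow rather than from this PR:

```python
import torch

from brevitas.graph.gpfq import gpfq_mode

# Hypothetical inputs: `model` is an already-quantized Brevitas model,
# `calib_loader` yields calibration batches of images.
model.eval()
with torch.no_grad():
    # use_gpfa2q=True swaps the per-layer optimizer from GPFQ to GPFA2Q,
    # which clips the running l1-norm of each output channel against the
    # A2Q upper bound implied by accumulator_bit_width.
    with gpfq_mode(model,
                   p=1.0,
                   act_order=False,
                   use_gpfa2q=True,
                   accumulator_bit_width=16) as gpfq:
        gpfq_model = gpfq.model
        for _ in range(gpfq.num_layers):
            for images, _ in calib_loader:
                gpfq_model(images)
            gpfq.update()
```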
9 changes: 2 additions & 7 deletions src/brevitas/graph/gptq.py
@@ -114,13 +114,8 @@ class GPTQ(GPxQ):
"""

def __init__(
self,
layer,
name,
act_order,
len_parallel_layers=1,
create_weight_orig=True,
num_blocks=100) -> None:
self, layer, name, act_order, len_parallel_layers, create_weight_orig,
num_blocks) -> None:
super().__init__(layer, name, act_order, len_parallel_layers, create_weight_orig)

dev = self.layer.weight.device
14 changes: 13 additions & 1 deletion src/brevitas/graph/gpxq.py
@@ -11,6 +11,8 @@
from typing import List, Optional, Set
import warnings

import torch

from brevitas.graph.calibrate import DisableEnableQuantization
import brevitas.nn as qnn
from brevitas.quant_tensor import QuantTensor
@@ -175,13 +177,23 @@ def process_input(self, inp):
if self.layer.weight_quant_requires_quant_input:
# Can minimize memory allocation by not storing actual values
self.quant_input = QuantTensor(
value=None,
value=torch.empty(
1, dtype=self.layer.weight.dtype, device=self.layer.weight.device),
scale=inp.scale,
zero_point=inp.zero_point,
bit_width=inp.bit_width,
signed=inp.signed,
training=inp.training)
inp = inp.value
elif self.layer.is_input_quant_enabled:
self.quant_input = QuantTensor(
value=torch.empty(
1, dtype=self.layer.weight.dtype, device=self.layer.weight.device),
scale=self.layer.quant_input_scale(),
zero_point=self.layer.quant_input_zero_point(),
bit_width=self.layer.quant_input_bit_width(),
signed=self.layer.is_quant_input_signed,
training=self.layer.training)

# If input is unbatched, add batch_size = 1
if len(inp.shape) == 1:
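
A small sketch of what the amended `process_input` records when only the input quantizer is enabled: a `QuantTensor` carrying the quantization metadata (scale, zero-point, bit-width, sign) with a 1-element placeholder in place of the actual values. The constructor arguments mirror the diff above; the concrete numbers below are made up:

```python
import torch

from brevitas.quant_tensor import QuantTensor

# Metadata-only placeholder: a 1-element tensor stands in for the real
# activations, so only scale/zero_point/bit_width/signed are retained.
quant_input = QuantTensor(
    value=torch.empty(1, dtype=torch.float32),
    scale=torch.tensor(0.05),    # made-up scale
    zero_point=torch.tensor(0.),
    bit_width=torch.tensor(8.),
    signed=False,
    training=False)

# GPFA2Q later reads only the metadata fields:
print(quant_input.bit_width, quant_input.signed)
```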
@@ -81,6 +81,8 @@ def unique(sequence):
'scale_factor_type': ['float_scale'], # Scale factor type
'weight_mantissa_bit_width': [4],
'weight_exponent_bit_width': [3],
'weight_narrow_range': [False],
'layerwise_first_last_bit_width': [8], # Input and weights bit width for first and last layer
'act_mantissa_bit_width': [4],
'act_exponent_bit_width': [3],
'weight_bit_width': [8], # Weight Bit Width
@@ -95,10 +97,12 @@ def unique(sequence):
'graph_eq_merge_bias': [True], # Merge bias for Graph Equalization
'act_equalization': ['layerwise'], # Perform Activation Equalization (Smoothquant)
'learned_round': [False], # Enable/Disable Learned Round
'gptq': [True], # Enable/Disable GPTQ
'gptq': [False], # Enable/Disable GPTQ
'gpfq': [False], # Enable/Disable GPFQ
'gpfq_p': [0.75], # GPFQ P
'gptq_act_order': [False], # Use act_order heuristics for GPTQ
'gpfa2q': [False], # Enable/Disable GPFA2Q
'gpfq_p': [1.0], # GPFQ P
'gpxq_act_order': [False], # Use act_order heuristics for GPxQ
'accumulator_bit_width': [16], # Accumulator bit width, only in combination with GPFA2Q
'act_quant_percentile': [99.999], # Activation Quantization Percentile
'uint_sym_act_for_unsigned_values': [True], # Whether to use unsigned act quant when possible
}
@@ -221,6 +225,8 @@ def ptq_torchvision_models(args):
quant_format=config_namespace.quant_format,
backend=config_namespace.target_backend,
act_bit_width=config_namespace.act_bit_width,
layerwise_first_last_bit_width=config_namespace.layerwise_first_last_bit_width,
weight_narrow_range=config_namespace.weight_narrow_range,
weight_mantissa_bit_width=config_namespace.weight_mantissa_bit_width,
weight_exponent_bit_width=config_namespace.weight_exponent_bit_width,
act_mantissa_bit_width=config_namespace.act_mantissa_bit_width,
@@ -247,11 +253,25 @@ def ptq_torchvision_models(args):

if config_namespace.gpfq:
print("Performing GPFQ:")
apply_gpfq(calib_loader, quant_model, p=config_namespace.gpfq_p)
apply_gpfq(
calib_loader,
quant_model,
p=config_namespace.gpfq_p,
act_order=config_namespace.gpxq_act_order)

if config_namespace.gpfa2q:
print("Performing GPFA2Q:")
apply_gpfq(
calib_loader,
quant_model,
p=config_namespace.gpfq_p,
act_order=config_namespace.gpxq_act_order,
gpfa2q=config_namespace.gpfa2q,
accumulator_bit_width=config_namespace.accumulator_bit_width)

if config_namespace.gptq:
print("Performing gptq")
apply_gptq(calib_loader, quant_model, config_namespace.gptq_act_order)
apply_gptq(calib_loader, quant_model, config_namespace.gpxq_act_order)

if config_namespace.learned_round:
print("Applying Learned Round:")
@@ -309,8 +329,10 @@ def validate_config(config_namespace):
if (config_namespace.target_backend == 'fx' or config_namespace.target_backend
== 'layerwise') and config_namespace.bias_bit_width == 16:
is_valid = False
# If GPTQ is disabled, we do not care about the act_order heuristic
if not config_namespace.gptq and config_namespace.gptq_act_order:
# At most one of GPTQ, GPFQ, or GPFA2Q can be enabled
multiple_gpxqs = float(config_namespace.gpfq) + float(config_namespace.gptq) + float(
config_namespace.gpfa2q)
if multiple_gpxqs > 1:
is_valid = False

if config_namespace.act_equalization == 'layerwise' and config_namespace.target_backend == 'fx':
@@ -320,9 +342,12 @@ def validate_config(config_namespace):

if config_namespace.act_param_method == 'mse':
config_namespace.act_quant_percentile = None

if not config_namespace.gpfq:
# gpfq_p is needed for GPFQ and GPFA2Q
if not config_namespace.gpfq and not config_namespace.gpfa2q:
config_namespace.gpfq_p = None
# accumulator bit width is only needed for GPFA2Q
if not config_namespace.gpfa2q:
config_namespace.accumulator_bit_width = None

if config_namespace.quant_format == 'int':
config_namespace.weight_mantissa_bit_width = None
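
As a quick illustration of the updated `validate_config` rule (at most one of GPTQ, GPFQ, GPFA2Q may be enabled per configuration), a standalone check that assumes nothing about the rest of the benchmark harness:

```python
def at_most_one_gpxq(gptq: bool, gpfq: bool, gpfa2q: bool) -> bool:
    """Mirrors the mutual-exclusivity rule added to validate_config."""
    return int(gptq) + int(gpfq) + int(gpfa2q) <= 1

# Enabling only GPFA2Q is a valid configuration...
assert at_most_one_gpxq(gptq=False, gpfq=False, gpfa2q=True)
# ...while enabling two GPxQ methods at once is rejected.
assert not at_most_one_gpxq(gptq=True, gpfq=True, gpfa2q=False)
```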
@@ -18,8 +18,10 @@ python ptq_benchmark_torchvision.py $1 --calibration-dir /scratch/datasets/image
--act_equalization layerwise \
--learned_round False \
--gptq False \
--gptq_act_order False \
--gpxq_act_order False \
--gpfq False \
--gpfq_p None \
--gpfa2q False \
--accumulator_bit_width None \
--uint_sym_act_for_unsigned_values False \
--act_quant_percentile None \