Commit ebb87e7

Giuseppe5 committed Jul 25, 2023
1 parent b015669 commit ebb87e7
Showing 3 changed files with 26 additions and 23 deletions.
37 changes: 19 additions & 18 deletions src/brevitas/graph/gpxq.py
@@ -1,7 +1,8 @@
# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause

-from abc import ABC, abstractmethod
+from abc import ABC
+from abc import abstractmethod
from copy import deepcopy
from dataclasses import dataclass
from dataclasses import field
@@ -12,12 +13,13 @@
import warnings

import torch

try:
    from torch.linalg import LinAlgError
except:
    LinAlgError = RuntimeError
+import numpy as np
import unfoldNd

from brevitas.graph.calibrate import DisableEnableQuantization
import brevitas.nn as qnn
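
The try/except kept here is a version-compatibility fallback: `torch.linalg.LinAlgError` does not exist in older PyTorch releases, where linear-algebra failures surface as `RuntimeError` instead, so aliasing the name lets the rest of the file catch a single exception type. A minimal sketch of the pattern (the `try_cholesky` helper is illustrative, not part of Brevitas):

```python
import torch

try:
    from torch.linalg import LinAlgError
except ImportError:
    # Older PyTorch: alias to the exception those versions actually raise.
    LinAlgError = RuntimeError

def try_cholesky(h):
    # Callers can catch one exception type regardless of PyTorch version.
    try:
        return torch.linalg.cholesky(h)
    except LinAlgError:
        return None  # e.g. a Hessian that is not positive definite

print(try_cholesky(torch.eye(3)))       # lower-triangular factor
print(try_cholesky(torch.zeros(3, 3)))  # None: singular input
```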
@@ -38,6 +40,7 @@ class LayerHandler:


class gpxq_mode(ABC):
+
    def __init__(
            self,
            model,
@@ -104,11 +107,7 @@ def __enter__(self):
            # Attach hooks for GPTQ
            if self._is_module_supported(module):
                gpxq = self.class_implementation(
-                   module,
-                   name,
-                   # num_blocks=self.num_blocks,
-                   act_order=self.act_order,
-                   parallel_layers=parallel_layers)
+                   module, name, act_order=self.act_order, parallel_layers=parallel_layers)
                hook_fn = partial(gpxq.update_batch, current_layer=self.current_layer)
                self.hook_dict[name] = module.register_forward_pre_hook(hook_fn)
                self.gpxq_layers[name] = gpxq
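
The hook registration above is the standard PyTorch pattern: bind extra state into the callback with `functools.partial`, then attach it as a forward pre-hook so every forward pass hands the layer its inputs. A self-contained sketch of the same mechanism (module and names are illustrative, not the Brevitas API):

```python
from functools import partial

import torch
import torch.nn as nn

def update_batch(module, inp, current_layer):
    # Forward pre-hooks receive the module and the tuple of positional
    # inputs; here we just record each input shape per module.
    current_layer.setdefault(module, []).append(inp[0].shape)

state = {}
model = nn.Sequential(nn.Linear(8, 4), nn.ReLU(), nn.Linear(4, 2))
handles = []
for name, module in model.named_modules():
    if isinstance(module, nn.Linear):
        # Bind shared state as a keyword argument, as gpxq_mode does.
        hook_fn = partial(update_batch, current_layer=state)
        handles.append(module.register_forward_pre_hook(hook_fn))

model(torch.randn(3, 8))
for h in handles:
    h.remove()  # detach hooks once calibration data is collected
```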
@@ -139,6 +138,7 @@ def update(self):
    def catch_stopfwd(self, *args, **kwargs):
        pass

+
class gptq_mode(gpxq_mode):
"""
Apply GPTQ algorithm https://arxiv.org/abs/2210.17323.
@@ -185,6 +185,7 @@ def catch_stopfwd(self, *args, **kwargs):
        except StopFwdException:
            pass

+
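
For readers skimming the diff: `gptq_mode` drives the GPTQ algorithm cited in the docstring above, whose core step quantizes weight columns one at a time and pushes the scaled quantization error onto the columns not yet quantized (the blockwise `weight[:, perm[i2:]] -= ...` line further down). A simplified per-column sketch, not the Brevitas implementation:

```python
import torch

def gptq_single_layer_update(weight, h_inv_chol, quantize):
    # weight: (OC, IC) float tensor, updated in place.
    # h_inv_chol: (IC, IC) upper-triangular Cholesky factor of the inverse Hessian.
    # quantize: callable mapping a float column (OC,) to its quantized value.
    for i in range(weight.shape[1]):
        w = weight[:, i]
        q = quantize(w)
        # Error scaled by the diagonal of the Cholesky factor.
        err = (w - q) / h_inv_chol[i, i]
        weight[:, i] = q
        # Propagate the error to the not-yet-quantized columns.
        weight[:, i + 1:] -= err.unsqueeze(1) * h_inv_chol[i, i + 1:].unsqueeze(0)
    return weight
```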
class gpfq_mode(gpxq_mode):
"""
Apply GPTQ algorithm https://arxiv.org/abs/2210.17323.
@@ -245,7 +246,7 @@ def catch_stopfwd(self, *args, **kwargs):
        for module in self.model.modules():
            if hasattr(module, 'weight_orig_data'):
                module.weight.data = quant_weight[module]
        # Re-enable quantization. If activation quantization is disabled,
        # we also disable bias quantization
        self.disable_quant_inference.enable_param_quantization(self.model, is_training=False)
        if self.use_quant_activations:
@@ -276,7 +277,7 @@ def __init__(self, layer, name, act_order, parallel_layers=1) -> None:
        # Number of columns is equal to the input channels (IC)
        self.columns = weight.shape[1]
        self.parallel_layers = parallel_layers

    def process_input(self, inp):
        # Input is a tuple, so we take first element
        inp = inp[0]
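
`process_input` receives the tuple that PyTorch passes to forward pre-hooks. For convolutional layers, GPxQ relies on the `unfoldNd` import added above to flatten input patches into columns, so the per-column updates written for linear layers carry over. A small sketch of that im2col step (assumed usage, shapes illustrative):

```python
import torch
import unfoldNd

# A batch entering a Conv3d(4, 8, kernel_size=3, padding=1).
x = torch.randn(2, 4, 8, 8, 8)
unfold = unfoldNd.UnfoldNd(kernel_size=3, padding=1, stride=1)
# (batch, IC * kernel volume, patches): each column is one receptive
# field, so "columns" again plays the role of input channels.
patches = unfold(x)
print(patches.shape)  # torch.Size([2, 108, 512])
```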
@@ -299,6 +300,7 @@ def process_input(self, inp):
    @abstractmethod
    def update_batch(self, module, input, current_layer):
        pass
+
    @abstractmethod
    def single_layer_update(self, percdamp=.01):
        pass
@@ -307,7 +309,7 @@ def get_quant_weights(self, i, permutation_list, i1=0):
        # We need to recompute quant weights at runtime since our float weights are being updated

        # Add offset in case of blockwise computation (e.g., GPTQ)
-        i = i1+i
+        i = i1 + i

        # For QuantLinear and for some QuantConvolutional layers, we exploit the possibility
        # of quantizing only a subset of the entire matrix speeding up the computation of GPTQ
@@ -562,12 +564,11 @@ def single_layer_update(self, percdamp=.01):
            weight[:, perm[i2:]] -= (error_block.matmul(h_inv[0, i1:i2, i2:])).to(dtype)


-
-
class GPFQ(GPxQ):
    """
    Based on https://github.com/YixuanSeanZhou/Quantized_Neural_Nets/tree/main
    """
+
    def __init__(self, layer, name, act_order, parallel_layers=1) -> None:
        self.layer = layer
        self.name = name
@@ -587,7 +588,6 @@ def __init__(self, layer, name, act_order, parallel_layers=1) -> None:
        # Number of columns is equal to the input channels (IC)
        self.columns = weight.shape[1]

-
        # Initialize Hessian matrix and counter. We need it in float32 to compute the inverse
        self.nsamples = 0
        self.parallel_layers = parallel_layers
@@ -596,9 +596,8 @@ def __init__(self, layer, name, act_order, parallel_layers=1) -> None:
        self.index_computed = False
        self.p = 0.25

    def update_batch(self, module, input, current_layer):
-
        # Update reference to current layer
        current_layer.layer_names.add(self.name)
        is_quant_disabled = module.weight_quant.disable_quant
@@ -695,7 +694,8 @@ def single_layer_update(self):
        for i in range(self.groups):
            self.quant_input = self.quant_input[i]
            self.float_input = self.float_input[i]
-            U[i:i+1] += weight[i:i+1, t].unsqueeze(1) * self.float_input[i, t].unsqueeze(0)
+            U[i:i +
+              1] += weight[i:i + 1, t].unsqueeze(1) * self.float_input[i, t].unsqueeze(0)
        norm = torch.linalg.norm(self.quant_input[:, t], 2) ** 2
        if norm > 0:
            q_arg = U.matmul(self.quant_input[:, t]) / norm
@@ -705,7 +705,8 @@ def single_layer_update(self):
            weight[:, t] = q_arg
        q = self.get_quant_weights(t, [torch.tensor(range(weight.shape[1]))])
        for i in range(self.groups):
-            U[i:i+1] -= q[i:i+1].unsqueeze(1) * self.quant_input[i:i+1, t].unsqueeze(0)
+            U[i:i +
+              1] -= q[i:i + 1].unsqueeze(1) * self.quant_input[i:i + 1, t].unsqueeze(0)

        del self.float_input
        del self.quant_input
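
The two `U` updates above are GPFQ's greedy path-following step: `U` accumulates the gap between the float and quantized forward passes, each column's argument is the least-squares projection of `U` onto the quantized input, and the chosen quantized contribution is subtracted back out. A single-group sketch under those assumptions, not the Brevitas implementation:

```python
import torch

def gpfq_single_layer_update(weight, float_input, quant_input, quantize):
    # weight: (OC, IC); float_input / quant_input: (N, IC) inputs collected
    # with quantization disabled / enabled, respectively.
    # quantize: callable mapping a float column (OC,) to quantized values.
    out_channels, columns = weight.shape
    U = torch.zeros(out_channels, float_input.shape[0])
    for t in range(columns):
        # Accumulate column t's contribution along the float path.
        U += weight[:, t].unsqueeze(1) * float_input[:, t].unsqueeze(0)
        norm = torch.linalg.norm(quant_input[:, t], 2) ** 2
        if norm > 0:
            # Greedy argument: least-squares projection onto the quantized input.
            q_arg = U.matmul(quant_input[:, t]) / norm
        else:
            q_arg = torch.zeros(out_channels)
        weight[:, t] = q_arg
        q = quantize(q_arg)  # stands in for get_quant_weights
        # Subtract the quantized contribution from the running error.
        U -= q.unsqueeze(1) * quant_input[:, t].unsqueeze(0)
    return weight
```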
src/brevitas_examples/imagenet_classification/ptq/ptq_common.py
@@ -2,7 +2,6 @@
# SPDX-License-Identifier: BSD-3-Clause

from copy import deepcopy
-from brevitas.graph.gpxq import gptq_mode, gpfq_mode

import torch
import torch.backends.cudnn as cudnn
@@ -13,6 +12,8 @@
from brevitas.graph.calibrate import calibration_mode
from brevitas.graph.calibrate import norm_correction_mode
from brevitas.graph.equalize import activation_equalization_mode
+from brevitas.graph.gpxq import gpfq_mode
+from brevitas.graph.gpxq import gptq_mode
from brevitas.graph.quantize import COMPUTE_LAYER_MAP
from brevitas.graph.quantize import LAYERWISE_COMPUTE_LAYER_MAP
from brevitas.graph.quantize import layerwise_quantize
@@ -256,7 +257,7 @@ def apply_gptq(calib_loader, model, act_order=False):
    dtype = next(model.parameters()).dtype
    device = next(model.parameters()).device
    with torch.no_grad():
-        with gptq_mode(model, act_order=act_order, use_quant_activations=True) as gptq:
+        with gptq_mode(model, act_order=act_order, use_quant_activations=False) as gptq:
            gptq_model = gptq.model
            for i in tqdm(range(gptq.num_layers)):
                for i, (images, target) in enumerate(calib_loader):
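
The one functional change here flips `use_quant_activations` to False, so GPTQ now calibrates with activation quantization disabled. For context, the surrounding loop follows the usual GPxQ driver pattern; a condensed sketch assuming the `gptq_mode` API shown in this diff:

```python
import torch
from brevitas.graph.gpxq import gptq_mode

def apply_gptq_sketch(calib_loader, model, act_order=False):
    device = next(model.parameters()).device
    with torch.no_grad():
        with gptq_mode(model, act_order=act_order, use_quant_activations=False) as gptq:
            for _ in range(gptq.num_layers):
                for images, _ in calib_loader:
                    # Hooks capture per-layer inputs; forward is cut short
                    # internally once the current layer has been reached.
                    gptq.model(images.to(device))
                gptq.update()  # quantize the current layer, move to the next
```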
src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py
@@ -20,8 +20,9 @@
from brevitas.graph.equalize import activation_equalization_mode
from brevitas.graph.quantize import preprocess_for_quantize
from brevitas.graph.target.flexml import preprocess_for_flexml_quantize
-from brevitas_examples.imagenet_classification.ptq.ptq_common import apply_act_equalization, apply_gpfq
+from brevitas_examples.imagenet_classification.ptq.ptq_common import apply_act_equalization
from brevitas_examples.imagenet_classification.ptq.ptq_common import apply_bias_correction
+from brevitas_examples.imagenet_classification.ptq.ptq_common import apply_gpfq
from brevitas_examples.imagenet_classification.ptq.ptq_common import apply_gptq
from brevitas_examples.imagenet_classification.ptq.ptq_common import apply_learned_round_learning
from brevitas_examples.imagenet_classification.ptq.ptq_common import calibrate
@@ -154,7 +155,7 @@
    default=True,
    help='Narrow range for weight quantization (default: enabled)')
add_bool_arg(parser, 'gptq', default=True, help='GPTQ (default: enabled)')
-add_bool_arg(parser, 'gpfq', default=False, help='GPTQ (default: disabled)')
+add_bool_arg(parser, 'gpfq', default=False, help='GPFQ (default: disabled)')
add_bool_arg(
    parser, 'gptq-act-order', default=False, help='GPTQ Act order heuristic (default: disabled)')
add_bool_arg(parser, 'learned-round', default=False, help='Learned round (default: disabled)')
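
Besides the help-string typo fix (GPTQ to GPFQ), nothing changes here. `add_bool_arg` is the usual paired-flag argparse helper; a minimal sketch of how such a helper is commonly written (an assumption about the helper, not necessarily Brevitas's exact code):

```python
import argparse

def add_bool_arg(parser, name, default, help):
    # Register --name / --no-name as a mutually exclusive boolean pair.
    dest = name.replace('-', '_')
    group = parser.add_mutually_exclusive_group(required=False)
    group.add_argument('--' + name, dest=dest, action='store_true', help=help)
    group.add_argument('--no-' + name, dest=dest, action='store_false', help='Disable ' + name)
    parser.set_defaults(**{dest: default})

parser = argparse.ArgumentParser()
add_bool_arg(parser, 'gpfq', default=False, help='GPFQ (default: disabled)')
print(parser.parse_args(['--gpfq']).gpfq)  # True
```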
@@ -287,7 +288,7 @@ def main():
    if args.gptq:
        print("Performing GPTQ:")
        apply_gptq(calib_loader, quant_model, args.gptq_act_order)

    if args.gpfq:
        print("Performing GPFQ:")
        apply_gpfq(calib_loader, quant_model)
