From ebb87e7e69d95debc2d45372ac50f543870d16be Mon Sep 17 00:00:00 2001
From: Giuseppe Franco
Date: Tue, 25 Jul 2023 12:16:14 +0100
Subject: [PATCH] Review

---
 src/brevitas/graph/gpxq.py                    | 37 ++++++++++---------
 .../imagenet_classification/ptq/ptq_common.py |  5 ++-
 .../ptq/ptq_evaluate.py                       |  7 ++--
 3 files changed, 26 insertions(+), 23 deletions(-)

diff --git a/src/brevitas/graph/gpxq.py b/src/brevitas/graph/gpxq.py
index d2e190936..5611ecfe6 100644
--- a/src/brevitas/graph/gpxq.py
+++ b/src/brevitas/graph/gpxq.py
@@ -1,7 +1,8 @@
 # Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
 # SPDX-License-Identifier: BSD-3-Clause

-from abc import ABC, abstractmethod
+from abc import ABC
+from abc import abstractmethod
 from copy import deepcopy
 from dataclasses import dataclass
 from dataclasses import field
@@ -12,12 +13,13 @@
 import warnings

 import torch
+
 try:
     from torch.linalg import LinAlgError
 except:
-    LinAlgError = RuntimeError
-import unfoldNd
+    LinAlgError = RuntimeError
 import numpy as np
+import unfoldNd

 from brevitas.graph.calibrate import DisableEnableQuantization
 import brevitas.nn as qnn
@@ -38,6 +40,7 @@ class LayerHandler:


 class gpxq_mode(ABC):
+
     def __init__(
             self,
             model,
@@ -104,11 +107,7 @@ def __enter__(self):
             # Attach hooks for GPTQ
             if self._is_module_supported(module):
                 gpxq = self.class_implementation(
-                    module,
-                    name,
-                    # num_blocks=self.num_blocks,
-                    act_order=self.act_order,
-                    parallel_layers=parallel_layers)
+                    module, name, act_order=self.act_order, parallel_layers=parallel_layers)
                 hook_fn = partial(gpxq.update_batch, current_layer=self.current_layer)
                 self.hook_dict[name] = module.register_forward_pre_hook(hook_fn)
                 self.gpxq_layers[name] = gpxq
@@ -139,6 +138,7 @@ def update(self):
     def catch_stopfwd(self, *args, **kwargs):
         pass

+
 class gptq_mode(gpxq_mode):
     """
     Apply GPTQ algorithm https://arxiv.org/abs/2210.17323.
@@ -185,6 +185,7 @@ def catch_stopfwd(self, *args, **kwargs):
         except StopFwdException:
             pass

+
 class gpfq_mode(gpxq_mode):
     """
     Apply GPTQ algorithm https://arxiv.org/abs/2210.17323.
@@ -245,7 +246,7 @@ def catch_stopfwd(self, *args, **kwargs):
         for module in self.model.modules():
             if hasattr(module, 'weight_orig_data'):
                 module.weight.data = quant_weight[module]
-        # Re-enable quantization. If activation quantization is disabled,
+        # Re-enable quantization. If activation quantization is disabled,
         # we also disable bias quantization
         self.disable_quant_inference.enable_param_quantization(self.model, is_training=False)
         if self.use_quant_activations:
@@ -276,7 +277,7 @@ def __init__(self, layer, name, act_order, parallel_layers=1) -> None:
         # Number of columns is equal to the input channels (IC)
         self.columns = weight.shape[1]
         self.parallel_layers = parallel_layers
-
+
     def process_input(self, inp):
         # Input is a tuple, so we take first element
         inp = inp[0]
@@ -299,6 +300,7 @@ def process_input(self, inp):
     @abstractmethod
     def update_batch(self, module, input, current_layer):
         pass
+
     @abstractmethod
     def single_layer_update(self, percdamp=.01):
         pass
@@ -307,7 +309,7 @@ def get_quant_weights(self, i, permutation_list, i1=0):

         # We need to recompute quant weights at runtime since our float weights are being updated
         # Add offset in case of blockwise computation (e.g., GPTQ)
-        i = i1+i
+        i = i1 + i

         # For QuantLinear and for some QuantConvolutional layers, we exploit the possibility
         # of quantizing only a subset of the entire matrix speeding up the computation of GPTQ
@@ -562,12 +564,11 @@ def single_layer_update(self, percdamp=.01):
             weight[:, perm[i2:]] -= (error_block.matmul(h_inv[0, i1:i2, i2:])).to(dtype)


-
-
 class GPFQ(GPxQ):
     """
     Based on https://github.com/YixuanSeanZhou/Quantized_Neural_Nets/tree/main
     """
+
     def __init__(self, layer, name, act_order, parallel_layers=1) -> None:
         self.layer = layer
         self.name = name
@@ -587,7 +588,6 @@ def __init__(self, layer, name, act_order, parallel_layers=1) -> None:

         # Number of columns is equal to the input channels (IC)
         self.columns = weight.shape[1]
-
         # Initialize Hessian matrix and counter. We need it in float32 to compute the inverse
         self.nsamples = 0
         self.parallel_layers = parallel_layers
@@ -596,9 +596,8 @@ def __init__(self, layer, name, act_order, parallel_layers=1) -> None:
         self.index_computed = False
         self.p = 0.25

-
     def update_batch(self, module, input, current_layer):
-
+
         # Update reference to current layer
         current_layer.layer_names.add(self.name)
         is_quant_disabled = module.weight_quant.disable_quant
@@ -695,7 +694,8 @@ def single_layer_update(self):
             for i in range(self.groups):
                 self.quant_input = self.quant_input[i]
                 self.float_input = self.float_input[i]
-                U[i:i+1] += weight[i:i+1, t].unsqueeze(1) * self.float_input[i, t].unsqueeze(0)
+                U[i:i +
+                  1] += weight[i:i + 1, t].unsqueeze(1) * self.float_input[i, t].unsqueeze(0)
             norm = torch.linalg.norm(self.quant_input[:, t], 2) ** 2
             if norm > 0:
                 q_arg = U.matmul(self.quant_input[:, t]) / norm
@@ -705,7 +705,8 @@ def single_layer_update(self):
             weight[:, t] = q_arg
             q = self.get_quant_weights(t, [torch.tensor(range(weight.shape[1]))])
             for i in range(self.groups):
-                U[i:i+1] -= q[i:i+1].unsqueeze(1) * self.quant_input[i:i+1, t].unsqueeze(0)
+                U[i:i +
+                  1] -= q[i:i + 1].unsqueeze(1) * self.quant_input[i:i + 1, t].unsqueeze(0)

         del self.float_input
         del self.quant_input
diff --git a/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py b/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py
index c341a6adb..c228a5037 100644
--- a/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py
+++ b/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py
@@ -2,7 +2,6 @@
 # SPDX-License-Identifier: BSD-3-Clause

 from copy import deepcopy
-from brevitas.graph.gpxq import gptq_mode, gpfq_mode

 import torch
 import torch.backends.cudnn as cudnn
@@ -13,6 +12,8 @@
 from brevitas.graph.calibrate import calibration_mode
 from brevitas.graph.calibrate import norm_correction_mode
 from brevitas.graph.equalize import activation_equalization_mode
+from brevitas.graph.gpxq import gpfq_mode
+from brevitas.graph.gpxq import gptq_mode
 from brevitas.graph.quantize import COMPUTE_LAYER_MAP
 from brevitas.graph.quantize import LAYERWISE_COMPUTE_LAYER_MAP
 from brevitas.graph.quantize import layerwise_quantize
@@ -256,7 +257,7 @@ def apply_gptq(calib_loader, model, act_order=False):
     dtype = next(model.parameters()).dtype
     device = next(model.parameters()).device
     with torch.no_grad():
-        with gptq_mode(model, act_order=act_order, use_quant_activations=True) as gptq:
+        with gptq_mode(model, act_order=act_order, use_quant_activations=False) as gptq:
             gptq_model = gptq.model
             for i in tqdm(range(gptq.num_layers)):
                 for i, (images, target) in enumerate(calib_loader):
diff --git a/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py b/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py
index 6318d9c28..082733a88 100644
--- a/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py
+++ b/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py
@@ -20,8 +20,9 @@
 from brevitas.graph.equalize import activation_equalization_mode
 from brevitas.graph.quantize import preprocess_for_quantize
 from brevitas.graph.target.flexml import preprocess_for_flexml_quantize
-from brevitas_examples.imagenet_classification.ptq.ptq_common import apply_act_equalization, apply_gpfq
+from brevitas_examples.imagenet_classification.ptq.ptq_common import apply_act_equalization
 from brevitas_examples.imagenet_classification.ptq.ptq_common import apply_bias_correction
+from brevitas_examples.imagenet_classification.ptq.ptq_common import apply_gpfq
 from brevitas_examples.imagenet_classification.ptq.ptq_common import apply_gptq
 from brevitas_examples.imagenet_classification.ptq.ptq_common import apply_learned_round_learning
 from brevitas_examples.imagenet_classification.ptq.ptq_common import calibrate
@@ -154,7 +155,7 @@
     default=True,
     help='Narrow range for weight quantization (default: enabled)')
 add_bool_arg(parser, 'gptq', default=True, help='GPTQ (default: enabled)')
-add_bool_arg(parser, 'gpfq', default=False, help='GPTQ (default: disabled)')
+add_bool_arg(parser, 'gpfq', default=False, help='GPFQ (default: disabled)')
 add_bool_arg(
     parser, 'gptq-act-order', default=False, help='GPTQ Act order heuristic (default: disabled)')
 add_bool_arg(parser, 'learned-round', default=False, help='Learned round (default: disabled)')
@@ -287,7 +288,7 @@ def main():
     if args.gptq:
         print("Performing GPTQ:")
         apply_gptq(calib_loader, quant_model, args.gptq_act_order)
-
+
     if args.gpfq:
         print("Performing GPFQ:")
         apply_gpfq(calib_loader, quant_model)
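
Note for reviewers: the GPxQ modes touched above are driven as context managers around a calibration loop. Entering the mode registers forward pre-hooks that collect per-layer statistics, the calibration set is replayed once per layer, and update() advances to the next layer. Below is a minimal sketch of such a driver for GPFQ, modelled on the apply_gptq helper visible in this patch; it is an illustration only, and the gpfq_mode keyword arguments plus the num_layers/update() interface are assumed to mirror gptq_mode rather than being taken verbatim from this diff.

import torch
from tqdm import tqdm

from brevitas.graph.gpxq import gpfq_mode


def apply_gpfq_sketch(calib_loader, model, act_order=False):
    # Hypothetical driver loop, mirroring apply_gptq from ptq_common.py above.
    dtype = next(model.parameters()).dtype
    device = next(model.parameters()).device
    with torch.no_grad():
        # act_order is assumed to be accepted by gpfq_mode, as it is by gptq_mode.
        with gpfq_mode(model, act_order=act_order) as gpfq:
            gpfq_model = gpfq.model
            # One outer pass per layer: each pass replays the calibration set so the
            # forward pre-hooks can accumulate the statistics for that layer's update.
            for _ in tqdm(range(gpfq.num_layers)):
                for images, _ in calib_loader:
                    images = images.to(device=device, dtype=dtype)
                    gpfq_model(images)
                gpfq.update()

The real helper lives in ptq_common.py; the sketch only illustrates the control flow that the hooks installed in gpxq_mode.__enter__ rely on.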