diff --git a/src/brevitas/graph/gpfq.py b/src/brevitas/graph/gpfq.py new file mode 100644 index 000000000..b4fd66076 --- /dev/null +++ b/src/brevitas/graph/gpfq.py @@ -0,0 +1,235 @@ +# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +from copy import deepcopy +from typing import List, Optional + +import numpy as np +import torch +import unfoldNd + +from brevitas.graph.gpxq import GPxQ +from brevitas.graph.gpxq import gpxq_mode +from brevitas.graph.gpxq import StopFwdException +from brevitas.graph.gpxq import SUPPORTED_CONV_OP +import brevitas.nn as qnn + + +class gpfq_mode(gpxq_mode): + """ + Apply GPFQ algorithm. + + Args: + model (Module): The model to quantize with GPFQ + inplace (bool): Whether to apply GPFQ inplace or perform a deepcopy. Default: True + use_quant_activations (bool): Whether to leave activation quantization enabled while performing + GPFQ. Default: True + + Example: + >>> with torch.no_grad(): + >>> with gpfq_mode(model) as gpfq: + >>> gpfq_model = gpfq.model + >>> for i in tqdm(range(gpfq.num_layers)): + >>> for img, t in calib_loader: + >>> img = img.cuda() + >>> gpfq_model(img) + >>> gpfq.update() + """ + + def __init__( + self, + model, + group_of_parallel_layers: Optional[List[str]] = None, + inplace: bool = True, + use_quant_activations: bool = True, + p: float = 0.25, + return_forward_output: bool = False, + act_order: bool = False) -> None: + if not inplace: + model = deepcopy(model) + super().__init__( + model, + group_of_parallel_layers, + inplace, + use_quant_activations, + act_order, + return_forward_output) + + self.orig_forward = self.model.forward + self.model.forward = self.catch_stopfwd + self.class_implementation = GPFQ + GPFQ.p = p + + def catch_stopfwd(self, *args, **kwargs): + # Collect quant input + try: + self.orig_forward(*args, **kwargs) + except StopFwdException: + pass + # Before collecting float input, restore original float weights if they have been modified + quant_weight = dict() + for module in self.model.modules(): + if hasattr(module, 'weight_orig_data'): + quant_weight[module] = deepcopy(module.weight.data) + module.weight.data = module.weight_orig_data + # Disable quantization + self.disable_quant_inference.disable_param_quantization(self.model, is_training=False) + self.disable_quant_inference.disable_act_quantization(self.model, is_training=False) + # Collect float input + try: + self.orig_forward(*args, **kwargs) + except StopFwdException: + pass + # Restore correct weights + for module in self.model.modules(): + if hasattr(module, 'weight_orig_data'): + module.weight.data = quant_weight[module] + # Re-enable quantization.
If activation quantization is disabled, + # we also disable bias quantization + self.disable_quant_inference.enable_param_quantization(self.model, is_training=False) + if self.use_quant_activations: + self.disable_quant_inference.enable_act_quantization(self.model, is_training=False) + else: + self.disable_quant_inference.disable_bias_quantization(self.model, is_training=False) + + if self.return_forward_output: + # If we want to return the output of the network, we need to disable all hooks + for name, gpxq_class in self.gpxq_layers.items(): + gpxq_class.disable_pre_forward_hook = True + out = self.orig_forward(*args, **kwargs) + for name, gpxq_class in self.gpxq_layers.items(): + gpxq_class.disable_pre_forward_hook = False + return out + + +class GPFQ(GPxQ): + """ + Based on https://github.com/YixuanSeanZhou/Quantized_Neural_Nets/tree/main + """ + p = 0.25 + + def __init__(self, layer, name, act_order, parallel_layers=1) -> None: + + if act_order: + raise ValueError("Act_order is not supported in GPFQ") + + super().__init__(layer, name, act_order, parallel_layers) + self.float_input = None + self.quantized_input = None + self.index_computed = False + self.p = GPFQ.p + + def update_batch(self, module, input, current_layer): + if self.disable_pre_forward_hook: + return input + + # Update reference to current layer + current_layer.layer_names.add(self.name) + is_quant_disabled = module.weight_quant.disable_quant + + inp = self.process_input(input) + batch_size = inp.shape[0] + + # Preprocess the input to compute the Hessian + if isinstance(self.layer, qnn.QuantLinear): + if len(inp.shape) > 2: + inp = inp.reshape((-1, sum(inp.shape[2:]))) + # For QuantLinear layer, groups will be 1 + inp_processed = inp.unsqueeze(0) + + if isinstance(self.layer, SUPPORTED_CONV_OP): + # Pick the correct unfoldNd class + if isinstance(self.layer, (qnn.QuantConvTranspose1d, qnn.QuantConvTranspose2d)): + unfold_impl = unfoldNd.UnfoldTransposeNd + else: + unfold_impl = unfoldNd.UnfoldNd + + unfold = unfold_impl( + self.layer.kernel_size, + dilation=self.layer.dilation, + padding=self.layer.padding, + stride=self.layer.kernel_size) + + # Split input based on how many groups in convolution + inp_by_group = torch.chunk(inp, self.groups, 1) + inp_processed = [] + # Preprocess input by group + for i, inp in enumerate(inp_by_group): + + inp = unfold(inp) + + batch_size, num_blocks = inp.shape[0], inp.shape[-1] + inp = torch.transpose(inp, 1, 2) # shape (B, L, C*kernel_size[0]*kernel_size[1]) + inp = inp.reshape(-1, inp.size(-1)) # shape (B*L, C*kernel_size[0]*kernel_size[1]) + + if not self.index_computed: + self.index_computed = True + self.rand_indices = np.concatenate([ + np.random.choice( + np.arange(num_blocks * i, num_blocks * (i + 1)), + size=int( + self.p * num_blocks + 1 if self.p != 1 else self.p * num_blocks)) + for i in range(batch_size)]) # need to define self.p (probability) + + indexes = self.rand_indices + if np.max(self.rand_indices) > inp.shape[0]: + indexes = self.rand_indices < inp.shape[0] + indexes = self.rand_indices[indexes] + + inp = inp[indexes] + inp_processed.append(inp) + inp_processed = torch.stack(inp_processed) + + if is_quant_disabled: + if self.float_input is None: + self.float_input = inp_processed + else: + self.float_input = torch.cat([self.float_input, inp_processed], dim=1) + else: + if self.quantized_input is None: + self.quantized_input = inp_processed + else: + self.quantized_input = torch.cat([self.quantized_input, inp_processed], dim=1) + # If we are executing GPFQ with 
a group of parallel layers, we keep track of how many forward + # passes we have executed. Once we have executed as many as the number of parallel_layers, we raise + # StopFwdException + current_layer.forward_count += 1 + if current_layer.forward_count == len(self.parallel_layers): + current_layer.forward_count = 0 + raise StopFwdException + + def single_layer_update(self): + weight = self.layer.weight.data + dev = weight.device + dtype = weight.dtype + if isinstance(self.layer, SUPPORTED_CONV_OP): + if isinstance(self.layer, (qnn.QuantConvTranspose1d, qnn.QuantConvTranspose2d)): + weight = weight.transpose(1, 0) # This performs a view + weight = weight.flatten(1) + weight = weight.view(self.groups, -1, weight.shape[-1]) # [Groups, OC/Groups, IC] + U = torch.zeros( + weight.shape[0], weight.shape[1], self.float_input.shape[1], device=dev, dtype=dtype) + self.float_input = self.float_input.to(dev) + self.quantized_input = self.quantized_input.to(dev) + permutation_list = [torch.tensor(range(weight.shape[-1]))] + for t in range(weight.shape[-1]): + for group_index in range(self.groups): + U[group_index] += torch.matmul( + weight[group_index, :, t].unsqueeze(1), + self.float_input[group_index, :, + t].unsqueeze(0)) #[OC/Groups, 1] * [1, INSHAPE[1]] + norm = torch.linalg.norm(self.quantized_input[group_index, :, t], 2) ** 2 + if norm > 0: + q_arg = U[group_index].matmul(self.quantized_input[group_index, :, t]) / norm + else: + q_arg = torch.zeros_like(U[group_index, :, 0]) + + weight[group_index, :, t] = q_arg + q = self.get_quant_weights(t, 0, permutation_list) + for group_index in range(self.groups): + U[group_index] -= torch.matmul( + q[group_index].unsqueeze(1), + self.quantized_input[group_index, :, t].unsqueeze(0)) + + del self.float_input + del self.quantized_input diff --git a/src/brevitas/graph/gptq.py b/src/brevitas/graph/gptq.py new file mode 100644 index 000000000..32e3a7869 --- /dev/null +++ b/src/brevitas/graph/gptq.py @@ -0,0 +1,258 @@ +# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +from copy import deepcopy +import math +from typing import List, Optional, Set +import warnings + +import torch + +try: + from torch.linalg import LinAlgError +except ImportError: + LinAlgError = RuntimeError +import unfoldNd + +from brevitas.graph.gpxq import GPxQ +from brevitas.graph.gpxq import gpxq_mode +from brevitas.graph.gpxq import StopFwdException +from brevitas.graph.gpxq import SUPPORTED_CONV_OP +import brevitas.nn as qnn + + +class gptq_mode(gpxq_mode): + """ + Apply GPTQ algorithm https://arxiv.org/abs/2210.17323. + + Args: + model (Module): The model to quantize with GPTQ + inplace (bool): Whether to apply GPTQ inplace or perform a deepcopy. Default: True + use_quant_activations (bool): Whether to leave activation quantization enabled while performing + GPTQ.
Default: True + + Example: + >>> with torch.no_grad(): + >>> with gptq_mode(model) as gptq: + >>> gptq_model = gptq.model + >>> for i in tqdm(range(gptq.num_layers)): + >>> for img, t in calib_loader: + >>> img = img.cuda() + >>> gptq_model(img) + >>> gptq.update() + """ + + def __init__( + self, + model, + group_of_parallel_layers: Optional[List[str]] = None, + inplace: bool = True, + use_quant_activations: bool = True, + num_blocks: int = 100, + return_forward_output: bool = False, + act_order: bool = False) -> None: + if not inplace: + model = deepcopy(model) + super().__init__( + model, + group_of_parallel_layers, + inplace, + use_quant_activations, + act_order, + return_forward_output) + + self.orig_forward = self.model.forward + self.model.forward = self.catch_stopfwd + # How many sub-blocks to use during GPTQ for each layer + self.num_blocks = num_blocks + self.class_implementation = GPTQ + GPTQ.num_blocks = num_blocks + + def catch_stopfwd(self, *args, **kwargs): + try: + self.orig_forward(*args, **kwargs) + except StopFwdException: + pass + finally: + if self.return_forward_output: + # If we want to return the output of the network, we need to disable all hooks + for name, gpxq_class in self.gpxq_layers.items(): + gpxq_class.disable_pre_forward_hook = True + out = self.orig_forward(*args, **kwargs) + for name, gpxq_class in self.gpxq_layers.items(): + gpxq_class.disable_pre_forward_hook = False + return out + + +class GPTQ(GPxQ): + """ + Adapted from https://github.com/IST-DASLab/gptq, released under the following LICENSE: + + Copyright 2023 IST-DASLab + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + """ + num_blocks = 100 + + def __init__(self, layer, name, act_order, parallel_layers=1) -> None: + super().__init__(layer, name, act_order, parallel_layers) + + dev = self.layer.weight.device + + # Define how many columns to update in each mini-block + self.blocksize = math.ceil(self.columns / GPTQ.num_blocks) + + # Initialize Hessian matrix and counter.
We need it in float32 to compute the inverse + self.H = torch.zeros((self.groups, self.columns, self.columns), + device=dev, + dtype=torch.float32) + self.nsamples = 0 + + def update_batch(self, module, input, current_layer): + if self.disable_pre_forward_hook: + return input + + # Update reference to current layer + current_layer.layer_names.add(self.name) + inp = self.process_input(input) + batch_size = inp.shape[0] + + # Preprocess the input to compute the Hessian + if isinstance(self.layer, qnn.QuantLinear): + if len(inp.shape) > 2: + inp = inp.reshape((-1, sum(inp.shape[2:]))) + inp = inp.t() + # For QuantLinear layer, groups will be 1 + inp_processed = inp.unsqueeze(0) + + if isinstance(self.layer, SUPPORTED_CONV_OP): + # Pick the correct unfoldNd class + if isinstance(self.layer, (qnn.QuantConvTranspose1d, qnn.QuantConvTranspose2d)): + unfold_impl = unfoldNd.UnfoldTransposeNd + else: + unfold_impl = unfoldNd.UnfoldNd + + unfold = unfold_impl( + self.layer.kernel_size, + dilation=self.layer.dilation, + padding=self.layer.padding, + stride=self.layer.stride) + + # Split input based on how many groups in convolution + inp_by_group = torch.chunk(inp, self.groups, 1) + inp_processed = [] + # Preprocess input by group + for i, inp in enumerate(inp_by_group): + inp = unfold(inp) + inp = inp.transpose(1, 0) + inp = inp.flatten(1) + inp_processed.append(inp) + inp_processed = torch.stack(inp_processed) + + # Hessian computation + self.H *= self.nsamples / (self.nsamples + batch_size) + self.nsamples += batch_size + inp_processed = math.sqrt(2 / self.nsamples) * inp_processed.to(torch.float32) + self.H += inp_processed.bmm(inp_processed.transpose(2, 1)) + # If we are executing GPTQ with a group of parallel layers, we keep track of how many forward + # passes we have executed. Once we have executed as many as the number of parallel_layers, we raise + # StopFwdException + current_layer.forward_count += 1 + if current_layer.forward_count == len(self.parallel_layers): + current_layer.forward_count = 0 + raise StopFwdException + + def single_layer_update(self, percdamp=.01): + weight = self.layer.weight.data + dev = weight.device + + # Store the original dtype of the weights + # During computation, everything is converted to float32. + # When the weights are updated, we cast everything back to the original dtype + dtype = weight.dtype + + if isinstance(self.layer, SUPPORTED_CONV_OP): + if isinstance(self.layer, (qnn.QuantConvTranspose1d, qnn.QuantConvTranspose2d)): + weight = weight.transpose(1, 0) # This performs a view + weight = weight.flatten(1) + + # List with permutation tensors for the Hessian and Weight matrix. + # If act_order is False, the tensors will be ordered indexes. + # For groupwise convolution, we have one tensor per group, + # thus len(permutation_list) is always equal to self.groups. + # We do not explicitly permute the weight matrix, only the Hessian. + permutation_list = [] + weight = weight.view(self.groups, -1, weight.shape[-1]) + # For groupwise convolution, these operations are groupwise so we iterate + for i in range(self.groups): + # If a diagonal element of the Hessian is zero, we can zero out the corresponding + # column in the weight matrix.
+ # The diagonal element is set to 1 to avoid division-by-zero + dead = torch.diag(self.H[i, :, :]) == 0 + self.H[i, dead, dead] = 1 + # If the diagonal of activations is zero, we set the weight to zero + weight[i, :, dead] = 0 + if self.act_order: + # Re-order Hessian so that weights associated to + # higher magnitude activations are quantized first + perm = torch.argsort(torch.diag(self.H[i, :, :]), descending=True) + self.H[i, :, :] = self.H[i, perm, :][:, perm] + else: + # No permutation, permutation tensor is a ordered index + perm = torch.tensor(range(self.H.shape[-1]), device=dev) + permutation_list.append(perm) + + # Try/Except in case the inverse Hessian cannot be computed + try: + for i in range(self.groups): + damp = percdamp * torch.mean(torch.diag(self.H[i, :, :])) + diag = torch.arange(self.columns, device=dev) + self.H[i, diag, diag] += damp + self.H[i, :, :] = torch.linalg.cholesky(self.H[i, :, :]) + self.H[i, :, :] = torch.cholesky_inverse(self.H[i, :, :]) + self.H[i, :, :] = torch.linalg.cholesky(self.H[i, :, :], upper=True) + h_inv = self.H + except LinAlgError as e: + warnings.warn( + f'Failed to compute the inverse of the Hessian for layer {self.name} ' + f'GPTQ will not be applied. ' + f'Increasing the number of samples might fix this issue') + return + finally: + del self.H + + for i1 in range(0, self.columns, self.blocksize): + i2 = min(i1 + self.blocksize, self.columns) + count = i2 - i1 + error_block = torch.zeros_like( + weight[:, :, perm[i1:i2]], dtype=torch.float32) # [groups, OC/groups, i2-i1] + + h_inv_block = h_inv[:, i1:i2, i1:i2] + for i in range(count): + q_groups = self.get_quant_weights(i, i1, permutation_list) # [groups, OC/groups] + for group_index in range(self.groups): + perm = permutation_list[group_index] + q = q_groups[group_index] # [OC/groups] + w = weight[group_index, :, perm[i1:i2][i]].to(torch.float32) # [OC/groups] + d = h_inv_block[group_index, i, i] # [1] + error = (w - q) / d # [OC/groups] + error_block[group_index, :, i] = error + # We need to update the original weights + weight[group_index, :, perm[i1:i2][i:]] -= ( + error.unsqueeze(1).matmul(h_inv_block[group_index, i, + i:].unsqueeze(0))).to(dtype) + + for group_index in range(self.groups): + perm = permutation_list[group_index] + weight[group_index, :, perm[i2:]] -= ( + error_block[group_index].matmul(h_inv[group_index, i1:i2, i2:])).to(dtype) diff --git a/src/brevitas/graph/gpxq.py b/src/brevitas/graph/gpxq.py index c076b2827..9dc11d955 100644 --- a/src/brevitas/graph/gpxq.py +++ b/src/brevitas/graph/gpxq.py @@ -7,18 +7,12 @@ from dataclasses import dataclass from dataclasses import field from functools import partial -import math from operator import attrgetter from typing import List, Optional, Set import warnings -import torch - -try: - from torch.linalg import LinAlgError -except: - LinAlgError = RuntimeError import numpy as np +import torch import unfoldNd from brevitas.graph.calibrate import DisableEnableQuantization @@ -139,156 +133,6 @@ def catch_stopfwd(self, *args, **kwargs): pass -class gptq_mode(gpxq_mode): - """ - Apply GPTQ algorithm https://arxiv.org/abs/2210.17323. - - Args: - model (Module): The model to quantize with GPTQ - inplace (bool): Wheter to apply GPTQ inplace or perform a deepcopy. Default: True - use_quant_activations (bool): Wheter to leave quantize activations enabled while performing - GPTQ. 
Default: False - - Example: - >>> with torch.no_grad(): - >>> with gptq_mode(model) as gptq: - >>> gptq_model = gptq.model - >>> for i in tqdm(range(gptq.num_layers)): - >>> for img, t in calib_loader: - >>> img = img.cuda() - >>> gptq_model(img) - >>> gptq.update() - """ - - def __init__( - self, - model, - group_of_parallel_layers: Optional[List[str]] = None, - inplace: bool = True, - use_quant_activations: bool = True, - num_blocks: int = 100, - return_forward_output: bool = False, - act_order: bool = False) -> None: - if not inplace: - model = deepcopy(model) - super().__init__( - model, - group_of_parallel_layers, - inplace, - use_quant_activations, - act_order, - return_forward_output) - - self.orig_forward = self.model.forward - self.model.forward = self.catch_stopfwd - # How many subblock to use during GPTQ for each layer - self.num_blocks = num_blocks - self.class_implementation = GPTQ - GPTQ.num_blocks = num_blocks - - def catch_stopfwd(self, *args, **kwargs): - try: - self.orig_forward(*args, **kwargs) - except StopFwdException: - pass - finally: - if self.return_forward_output: - # If we want to return the output of the network, we need to disable all hooks - for name, gpxq_class in self.gpxq_layers.items(): - gpxq_class.disable_pre_forward_hook = True - out = self.orig_forward(*args, **kwargs) - for name, gpxq_class in self.gpxq_layers.items(): - gpxq_class.disable_pre_forward_hook = False - return out - - -class gpfq_mode(gpxq_mode): - """ - Apply GPFQ algorithm. - - Args: - model (Module): The model to quantize with GPFQ - inplace (bool): Wheter to apply GPFQ inplace or perform a deepcopy. Default: True - use_quant_activations (bool): Wheter to leave quantize activations enabled while performing - GPFQ. Default: False - - Example: - >>> with torch.no_grad(): - >>> with gpfq_mode(model) as gpfq: - >>> gpfq_model = gpfq.model - >>> for i in tqdm(range(gpfq.num_layers)): - >>> for img, t in calib_loader: - >>> img = img.cuda() - >>> gpfq_model(img) - >>> gpfq.update() - """ - - def __init__( - self, - model, - group_of_parallel_layers: Optional[List[str]] = None, - inplace: bool = True, - use_quant_activations: bool = True, - p: int = 0.25, - return_forward_output: bool = False, - act_order: bool = False) -> None: - if not inplace: - model = deepcopy(model) - super().__init__( - model, - group_of_parallel_layers, - inplace, - use_quant_activations, - act_order, - return_forward_output) - - self.orig_forward = self.model.forward - self.model.forward = self.catch_stopfwd - self.class_implementation = GPFQ - GPFQ.p = p - - def catch_stopfwd(self, *args, **kwargs): - # Collect quant input - try: - self.orig_forward(*args, **kwargs) - except StopFwdException: - pass - # Before collecting float input, restore original float weights if they have been modified - quant_weight = dict() - for module in self.model.modules(): - if hasattr(module, 'weight_orig_data'): - quant_weight[module] = deepcopy(module.weight.data) - module.weight.data = module.weight_orig_data - # Disable quantization - self.disable_quant_inference.disable_param_quantization(self.model, is_training=False) - self.disable_quant_inference.disable_act_quantization(self.model, is_training=False) - # Collect float input - try: - self.orig_forward(*args, **kwargs) - except StopFwdException: - pass - # Restore correct weights - for module in self.model.modules(): - if hasattr(module, 'weight_orig_data'): - module.weight.data = quant_weight[module] - # Re-enable quantization. 
If activation quantization is disabled, - # we also disable bias quantization - self.disable_quant_inference.enable_param_quantization(self.model, is_training=False) - if self.use_quant_activations: - self.disable_quant_inference.enable_act_quantization(self.model, is_training=False) - else: - self.disable_quant_inference.disable_bias_quantization(self.model, is_training=False) - - if self.return_forward_output: - # If we want to return the output of the network, we need to disable all hooks - for name, gpxq_class in self.gpxq_layers.items(): - gpxq_class.disable_pre_forward_hook = True - out = self.orig_forward(*args, **kwargs) - for name, gpxq_class in self.gpxq_layers.items(): - gpxq_class.disable_pre_forward_hook = False - return out - - class GPxQ(ABC): def __init__(self, layer, name, act_order, parallel_layers=1) -> None: @@ -401,305 +245,3 @@ def get_quant_weights(self, i, i1, permutation_list): # We need to remove the last dim q = q.squeeze(2) # [groups, OC/groups] or [1, OC] return q - - -class GPTQ(GPxQ): - """ - Adapted from https://github.com/IST-DASLab/gptq, released under the following LICENSE: - - Copyright 2023 IST-DASLab - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ - num_blocks = 100 - - def __init__(self, layer, name, act_order, parallel_layers=1) -> None: - super().__init__(layer, name, act_order, parallel_layers) - - dev = self.layer.weight.device - - # Define how many columns to update in each mini-block - self.blocksize = math.ceil(self.columns / GPTQ.num_blocks) - - # Initialize Hessian matrix and counter. 
We need it in float32 to compute the inverse - self.H = torch.zeros((self.groups, self.columns, self.columns), - device=dev, - dtype=torch.float32) - self.nsamples = 0 - - def update_batch(self, module, input, current_layer): - if self.disable_pre_forward_hook: - return input - - # Update reference to current layer - current_layer.layer_names.add(self.name) - inp = self.process_input(input) - batch_size = inp.shape[0] - - # Preprocess the input to compute the Hessian - if isinstance(self.layer, qnn.QuantLinear): - if len(inp.shape) > 2: - inp = inp.reshape((-1, sum(inp.shape[2:]))) - inp = inp.t() - # For QuantLinear layer, groups will be 1 - inp_processed = inp.unsqueeze(0) - - if isinstance(self.layer, SUPPORTED_CONV_OP): - # Pick the correct unfoldNd class - if isinstance(self.layer, (qnn.QuantConvTranspose1d, qnn.QuantConvTranspose2d)): - unfold_impl = unfoldNd.UnfoldTransposeNd - else: - unfold_impl = unfoldNd.UnfoldNd - - unfold = unfold_impl( - self.layer.kernel_size, - dilation=self.layer.dilation, - padding=self.layer.padding, - stride=self.layer.stride) - - # Split input based on how many groups in convolution - inp_by_group = torch.chunk(inp, self.groups, 1) - inp_processed = [] - # Preprocess input by group - for i, inp in enumerate(inp_by_group): - inp = unfold(inp) - inp = inp.transpose(1, 0) - inp = inp.flatten(1) - inp_processed.append(inp) - inp_processed = torch.stack(inp_processed) - - # Hessian computation - self.H *= self.nsamples / (self.nsamples + batch_size) - self.nsamples += batch_size - inp_processed = math.sqrt(2 / self.nsamples) * inp_processed.to(torch.float32) - self.H += inp_processed.bmm(inp_processed.transpose(2, 1)) - # If we are executing GPTQ with group of parallel layers, we keep track of how many forward - # we executed. Once we executed as many as the number of parallel_layers, we raise - # StopFwdException - current_layer.forward_count += 1 - if current_layer.forward_count == len(self.parallel_layers): - current_layer.forward_count = 0 - raise StopFwdException - - def single_layer_update(self, percdamp=.01): - weight = self.layer.weight.data - dev = weight.device - - # Store the original dtype of the weights - # During computation, everything is converted to float32. - # When the weights are updated, we cast everything back to the original dtype - dtype = weight.dtype - - if isinstance(self.layer, SUPPORTED_CONV_OP): - if isinstance(self.layer, (qnn.QuantConvTranspose1d, qnn.QuantConvTranspose2d)): - weight = weight.transpose(1, 0) # This performs a view - weight = weight.flatten(1) - - # List with permutation tensors for the Hessian and Weight matrix. - # If act_order is False, the tensors will be ordered indexes. - # For groupwise convolution, we have one tensor per group, - # thus len(permutation_list) is always equal to self.groups. - # We do not explicity permute the weight matrix, only the Hessian. - permutation_list = [] - weight = weight.view(self.groups, -1, weight.shape[-1]) - # For groupwise convolution, these operations are groupwise so we iterate - for i in range(self.groups): - # If a diagonal element on the Hessian is zero, we can set to 0 the corresponding - # column in the weight matrix. 
- # The diagonal element is set to 1 to avoid division-by-zero - dead = torch.diag(self.H[i, :, :]) == 0 - self.H[i, dead, dead] = 1 - # If the diagonal of activations is zero, we set the weight to zero - weight[i, :, dead] = 0 - if self.act_order: - # Re-order Hessian so that weights associated to - # higher magnitude activations are quantized first - perm = torch.argsort(torch.diag(self.H[i, :, :]), descending=True) - self.H[i, :, :] = self.H[i, perm, :][:, perm] - else: - # No permutation, permutation tensor is a ordered index - perm = torch.tensor(range(self.H.shape[-1]), device=dev) - permutation_list.append(perm) - - # Try/Except in case the inverse Hessian cannot be computed - try: - for i in range(self.groups): - damp = percdamp * torch.mean(torch.diag(self.H[i, :, :])) - diag = torch.arange(self.columns, device=dev) - self.H[i, diag, diag] += damp - self.H[i, :, :] = torch.linalg.cholesky(self.H[i, :, :]) - self.H[i, :, :] = torch.cholesky_inverse(self.H[i, :, :]) - self.H[i, :, :] = torch.linalg.cholesky(self.H[i, :, :], upper=True) - h_inv = self.H - except LinAlgError as e: - warnings.warn( - f'Failed to compute the inverse of the Hessian for layer {self.name} ' - f'GPTQ will not be applied. ' - f'Increasing the number of samples might fix this issue') - return - finally: - del self.H - - for i1 in range(0, self.columns, self.blocksize): - i2 = min(i1 + self.blocksize, self.columns) - count = i2 - i1 - error_block = torch.zeros_like( - weight[:, :, perm[i1:i2]], dtype=torch.float32) # [groups, OC/groups, i2-i1] - - h_inv_block = h_inv[:, i1:i2, i1:i2] - for i in range(count): - q_groups = self.get_quant_weights(i, i1, permutation_list) # [groups, OC/groups] - for group_index in range(self.groups): - perm = permutation_list[group_index] - q = q_groups[group_index] # [OC/groups] - w = weight[group_index, :, perm[i1:i2][i]].to(torch.float32) # [OC/groups] - d = h_inv_block[group_index, i, i] # [1] - error = (w - q) / d # [OC/groups] - error_block[group_index, :, i] = error - # We need to update the original weights - weight[group_index, :, perm[i1:i2][i:]] -= ( - error.unsqueeze(1).matmul(h_inv_block[group_index, i, - i:].unsqueeze(0))).to(dtype) - - for group_index in range(self.groups): - perm = permutation_list[group_index] - weight[group_index, :, perm[i2:]] -= ( - error_block[group_index].matmul(h_inv[group_index, i1:i2, i2:])).to(dtype) - - -class GPFQ(GPxQ): - """ - Based on https://github.com/YixuanSeanZhou/Quantized_Neural_Nets/tree/main - """ - p = 0.25 - - def __init__(self, layer, name, act_order, parallel_layers=1) -> None: - - if act_order: - raise ValueError("Act_order is not supported in GPFQ") - - super().__init__(layer, name, act_order, parallel_layers) - self.float_input = None - self.quantized_input = None - self.index_computed = False - self.p = GPFQ.p - - def update_batch(self, module, input, current_layer): - if self.disable_pre_forward_hook: - return input - - # Update reference to current layer - current_layer.layer_names.add(self.name) - is_quant_disabled = module.weight_quant.disable_quant - - inp = self.process_input(input) - batch_size = inp.shape[0] - - # Preprocess the input to compute the Hessian - if isinstance(self.layer, qnn.QuantLinear): - if len(inp.shape) > 2: - inp = inp.reshape((-1, sum(inp.shape[2:]))) - # For QuantLinear layer, groups will be 1 - inp_processed = inp.unsqueeze(0) - - if isinstance(self.layer, SUPPORTED_CONV_OP): - # Pick the correct unfoldNd class - if isinstance(self.layer, (qnn.QuantConvTranspose1d, 
qnn.QuantConvTranspose2d)): - unfold_impl = unfoldNd.UnfoldTransposeNd - else: - unfold_impl = unfoldNd.UnfoldNd - - unfold = unfold_impl( - self.layer.kernel_size, - dilation=self.layer.dilation, - padding=self.layer.padding, - stride=self.layer.kernel_size) - - # Split input based on how many groups in convolution - inp_by_group = torch.chunk(inp, self.groups, 1) - inp_processed = [] - # Preprocess input by group - for i, inp in enumerate(inp_by_group): - - inp = unfold(inp) - - batch_size, num_blocks = inp.shape[0], inp.shape[-1] - inp = torch.transpose(inp, 1, 2) # shape (B, L, C*kernel_size[0]*kernel_size[1]) - inp = inp.reshape(-1, inp.size(-1)) # shape (B*L, C*kernel_size[0]*kernel_size[1]) - - if not self.index_computed: - self.index_computed = True - self.rand_indices = np.concatenate([ - np.random.choice( - np.arange(num_blocks * i, num_blocks * (i + 1)), - size=int( - self.p * num_blocks + 1 if self.p != 1 else self.p * num_blocks)) - for i in range(batch_size)]) # need to define self.p (probability) - - indexes = self.rand_indices - if np.max(self.rand_indices) > inp.shape[0]: - indexes = self.rand_indices < inp.shape[0] - indexes = self.rand_indices[indexes] - - inp = inp[indexes] - inp_processed.append(inp) - inp_processed = torch.stack(inp_processed) - - if is_quant_disabled: - if self.float_input is None: - self.float_input = inp_processed - else: - self.float_input = torch.cat([self.float_input, inp_processed], dim=1) - else: - if self.quantized_input is None: - self.quantized_input = inp_processed - else: - self.quantized_input = torch.cat([self.quantized_input, inp_processed], dim=1) - # If we are executing GPFQ with group of parallel layers, we keep track of how many forward - # we executed. Once we executed as many as the number of parallel_layers, we raise - # StopFwdException - current_layer.forward_count += 1 - if current_layer.forward_count == len(self.parallel_layers): - current_layer.forward_count = 0 - raise StopFwdException - - def single_layer_update(self): - weight = self.layer.weight.data - dev = weight.device - dtype = weight.dtype - if isinstance(self.layer, SUPPORTED_CONV_OP): - if isinstance(self.layer, (qnn.QuantConvTranspose1d, qnn.QuantConvTranspose2d)): - weight = weight.transpose(1, 0) # This performs a view - weight = weight.flatten(1) - weight = weight.view(self.groups, -1, weight.shape[-1]) # [Groups, OC/Groups, IC] - U = torch.zeros( - weight.shape[0], weight.shape[1], self.float_input.shape[1], device=dev, dtype=dtype) - - for t in range(weight.shape[-1]): - for group_index in range(self.groups): - U[group_index] += weight[group_index, :, t].unsqueeze(1) * self.float_input[ - group_index, :, t].unsqueeze(0) #[OC/Groups, 1] * [1, INSHAPE[1]] - norm = torch.linalg.norm(self.quantized_input[group_index, :, t], 2) ** 2 - if norm > 0: - q_arg = U[group_index].matmul(self.quantized_input[group_index, :, t]) / norm - else: - q_arg = torch.zeros_like(U[group_index, :, 0]) - - weight[group_index, :, t] = q_arg - q = self.get_quant_weights(t, 0, [torch.tensor(range(weight.shape[-1]))]) - for group_index in range(self.groups): - U[group_index] -= q[group_index].unsqueeze(1) * self.quantized_input[ - group_index, :, t].unsqueeze(0) - - del self.float_input - del self.quantized_input diff --git a/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py b/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py index 7e4afc0f8..6e60ecb09 100644 --- a/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py +++ 
b/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py @@ -13,8 +13,8 @@ from brevitas.graph.calibrate import calibration_mode from brevitas.graph.calibrate import norm_correction_mode from brevitas.graph.equalize import activation_equalization_mode -from brevitas.graph.gpxq import gpfq_mode -from brevitas.graph.gpxq import gptq_mode +from brevitas.graph.gpfq import gpfq_mode +from brevitas.graph.gptq import gptq_mode from brevitas.graph.quantize import layerwise_quantize from brevitas.graph.quantize import quantize from brevitas.graph.target.flexml import quantize_flexml diff --git a/src/brevitas_examples/llm/llm_quant/gptq.py b/src/brevitas_examples/llm/llm_quant/gptq.py index ab5d78195..2e73bdf76 100644 --- a/src/brevitas_examples/llm/llm_quant/gptq.py +++ b/src/brevitas_examples/llm/llm_quant/gptq.py @@ -5,7 +5,7 @@ import torch -from brevitas.graph.gpxq import gptq_mode +from brevitas.graph.gptq import gptq_mode from brevitas_examples.llm.llm_quant.run_utils import apply_layer_ptq_fn
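Usage sketch for the relocated entry points, following the docstring examples above. It assumes a Brevitas-quantized model and a calib_loader that yields (image, target) batches on a CUDA device; gpfq_mode from brevitas.graph.gpfq follows the same pattern.

from tqdm import tqdm
import torch
from brevitas.graph.gptq import gptq_mode  # previously imported from brevitas.graph.gpxq

with torch.no_grad():
    with gptq_mode(model) as gptq:
        gptq_model = gptq.model
        # One outer pass per layer; gptq.update() quantizes the current layer
        for _ in tqdm(range(gptq.num_layers)):
            for img, _ in calib_loader:
                img = img.cuda()
                gptq_model(img)
            gptq.update()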