From a7b9b017c4eaee6c6d150d4c24915defbef38712 Mon Sep 17 00:00:00 2001 From: Fabian Grob Date: Thu, 16 May 2024 09:51:38 +0100 Subject: [PATCH 1/4] Feat (gpfq): separate float and quant forward pass for speedup --- src/brevitas/graph/gpfq.py | 143 +++++++++++++++--- src/brevitas/graph/gptq.py | 4 + src/brevitas/graph/gpxq.py | 27 ++-- .../imagenet_classification/ptq/ptq_common.py | 14 +- .../ptq/ptq_evaluate.py | 11 +- 5 files changed, 166 insertions(+), 33 deletions(-) diff --git a/src/brevitas/graph/gpfq.py b/src/brevitas/graph/gpfq.py index fd7df9223..fe007bad6 100644 --- a/src/brevitas/graph/gpfq.py +++ b/src/brevitas/graph/gpfq.py @@ -3,13 +3,10 @@ from copy import deepcopy import math -from math import pi from typing import Callable, List, Optional import numpy as np import torch -from torch.fft import fft -from torch.fft import fftn import torch.nn as nn import unfoldNd @@ -87,7 +84,8 @@ def __init__( use_gpfa2q: bool = False, accumulator_bit_width: Optional[int] = None, a2q_layer_filter_fnc: Optional[Callable[[nn.Module], bool]] = lambda x: True, - compression_rate: Optional[float] = 0.0) -> None: + compression_rate: Optional[float] = 0.0, + collect_float_first: bool = False) -> None: if not inplace: model = deepcopy(model) super().__init__( @@ -111,22 +109,34 @@ def __init__( if self.compression_rate < 0.0 or self.compression_rate > 1.0: raise ValueError('Compression rate for random projection must be between 0 and 1.') - def catch_stopfwd(self, *args, **kwargs): - # Collect quant input - try: - self.orig_forward(*args, **kwargs) - except StopFwdException: - pass + # speeding up by collecting float input first so we don't need to do it later + self.collect_float_first = collect_float_first + + def __enter__(self): + # initialize gpxq layers + self.setup_gpxq_layers() + if self.collect_float_first: + self.float_collection_hooks = dict() + # set up hooks for collecting the float input + for name, layer in self.gpxq_layers.items(): + # Attach float collecting hook + self.float_collection_hooks[name] = layer.layer.register_forward_hook( + layer.collect_float_input) + + # Disable quantization + self.return_quant_tensor_state = disable_return_quant_tensor(self.model) + self.disable_quant_inference.disable_param_quantization(self.model, is_training=False) + self.disable_quant_inference.disable_act_quantization(self.model, is_training=False) + + return self + else: + # if we're not collecting, setup original hooks + return self.setup_gpxq_hooks() - # Disable quantization - self.return_quant_tensor_state = disable_return_quant_tensor(self.model) - self.disable_quant_inference.disable_param_quantization(self.model, is_training=False) - self.disable_quant_inference.disable_act_quantization(self.model, is_training=False) - # Collect float input - try: - self.orig_forward(*args, **kwargs) - except StopFwdException: - pass + def finalize_float_collection(self): + # remove the hooks we attached during the float collection + for name, hook in self.float_collection_hooks.items(): + hook.remove() # Re-enable quantization. 
If activation quantization is disabled, # we also disable bias quantization @@ -137,6 +147,37 @@ def catch_stopfwd(self, *args, **kwargs): self.disable_quant_inference.disable_bias_quantization(self.model, is_training=False) restore_return_quant_tensor(self.model, self.return_quant_tensor_state) + # setup the original hooks + self.setup_gpxq_hooks() + + def catch_stopfwd(self, *args, **kwargs): + # Collect quant input + try: + self.orig_forward(*args, **kwargs) + except StopFwdException: + pass + + if not self.collect_float_first: + # Disable quantization + self.return_quant_tensor_state = disable_return_quant_tensor(self.model) + self.disable_quant_inference.disable_param_quantization(self.model, is_training=False) + self.disable_quant_inference.disable_act_quantization(self.model, is_training=False) + # Collect float input + try: + self.orig_forward(*args, **kwargs) + except StopFwdException: + pass + + # Re-enable quantization. If activation quantization is disabled, + # we also disable bias quantization + self.disable_quant_inference.enable_param_quantization(self.model, is_training=False) + if self.use_quant_activations: + self.disable_quant_inference.enable_act_quantization(self.model, is_training=False) + else: + self.disable_quant_inference.disable_bias_quantization( + self.model, is_training=False) + restore_return_quant_tensor(self.model, self.return_quant_tensor_state) + if self.return_forward_output: # If we want to return the output of the network, we need to disable all hooks for name, gpxq_class in self.gpxq_layers.items(): @@ -186,6 +227,70 @@ def __init__( self.p = p self.compression_rate = compression_rate + def collect_float_input(self, module, args, output): + # this is the hook function to offload the output of this layer to disc + inp = self.process_input(args) + batch_size = inp.shape[0] + + # Preprocess the input to compute the Hessian + if isinstance(self.layer, qnn.QuantLinear): + if len(inp.shape) > 2: + inp = inp.reshape((-1, sum(inp.shape[2:]))) + # For QuantLinear layer, groups will be 1 + inp_processed = inp.unsqueeze(0) + + if isinstance(self.layer, SUPPORTED_CONV_OP): + # Pick the correct unfoldNd class + if isinstance( + self.layer, + (qnn.QuantConvTranspose1d, qnn.QuantConvTranspose2d, qnn.QuantConvTranspose3d)): + unfold_impl = unfoldNd.UnfoldTransposeNd + else: + unfold_impl = unfoldNd.UnfoldNd + + unfold = unfold_impl( + self.layer.kernel_size, + dilation=self.layer.dilation, + padding=self.layer.padding, + stride=self.layer.kernel_size) + + # Split input based on how many groups in convolution + inp_by_group = torch.chunk(inp, self.groups, 1) + inp_processed = [] + # Preprocess input by group + for i, inp in enumerate(inp_by_group): + + inp = unfold(inp) + + batch_size, num_blocks = inp.shape[0], inp.shape[-1] + inp = torch.transpose(inp, 1, 2) # shape (B, L, C*kernel_size[0]*kernel_size[1]) + inp = inp.reshape(-1, inp.size(-1)) # shape (B*L, C*kernel_size[0]*kernel_size[1]) + + if not self.index_computed: + self.index_computed = True + self.rand_indices = np.concatenate([ + np.random.choice( + np.arange(num_blocks * i, num_blocks * (i + 1)), + size=int( + self.p * num_blocks + 1 if self.p != 1 else self.p * num_blocks)) + for i in range(batch_size)]) # need to define self.p (probability) + + indexes = self.rand_indices + if np.max(self.rand_indices) > inp.shape[0]: + indexes = self.rand_indices < inp.shape[0] + indexes = self.rand_indices[indexes] + + inp = inp[indexes] + inp_processed.append(inp) + inp_processed = torch.stack(inp_processed) + + 
inp_processed = inp_processed.cpu() + + if self.float_input is None: + self.float_input = inp_processed + else: + self.float_input = torch.cat([self.float_input, inp_processed], dim=1) + def update_batch(self, module, input, current_layer): if self.disable_pre_forward_hook: return input diff --git a/src/brevitas/graph/gptq.py b/src/brevitas/graph/gptq.py index 31d31433b..6661be97f 100644 --- a/src/brevitas/graph/gptq.py +++ b/src/brevitas/graph/gptq.py @@ -76,6 +76,10 @@ def __init__( # How many subblock to use during GPTQ for each layer self.num_blocks = num_blocks + def __enter__(self): + self.setup_gpxq_layers() + return self.setup_gpxq_hooks() + def catch_stopfwd(self, *args, **kwargs): try: self.orig_forward(*args, **kwargs) diff --git a/src/brevitas/graph/gpxq.py b/src/brevitas/graph/gpxq.py index fdbaee52f..dcd193c08 100644 --- a/src/brevitas/graph/gpxq.py +++ b/src/brevitas/graph/gpxq.py @@ -113,34 +113,38 @@ def _is_module_supported(self, module): else: return False + @abstractmethod def __enter__(self): + pass + + def setup_gpxq_layers(self): # The user can specify on which layers to apply gptq in parallel. # All the others will be executed sequentially - dict_of_layers = { + self.dict_of_layers = { name: [(name, module)] for name, module in self.model.named_modules() if self._is_module_supported(module)} if self.group_of_parallel_layers is not None: for parallel_layers in self.group_of_parallel_layers: for name in parallel_layers: - if name not in dict_of_layers: + if name not in self.dict_of_layers: raise ValueError( "The layer {} is not present in the model or it is not supported for GPTQ" .format(name)) - del dict_of_layers[name] + del self.dict_of_layers[name] names = '_'.join(parallel_layers) - dict_of_layers[names] = [ + self.dict_of_layers[names] = [ (name, attrgetter(name)(self.model)) for name in parallel_layers] # Print warning if hooks are attached to any module, since the normal forward flow of the # network is highly disrupted during GPxQ - for _, parallel_layers in dict_of_layers.items(): + for _, parallel_layers in self.dict_of_layers.items(): for name, module in parallel_layers: if len(module._forward_hooks) > 0 or len(module._forward_pre_hooks): warnings.warn( f'Hooks detected during setup for GPxQ. 
' f'Behaviour might deviate from what expected.') - # Attach hooks for GPTQ + # initialize GPxQ if self._is_module_supported(module): gpxq_module_optimizer = self.initialize_module_optimizer( module, @@ -148,11 +152,14 @@ def __enter__(self): act_order=self.act_order, len_parallel_layers=len(parallel_layers), create_weight_orig=self.create_weight_orig) - hook_fn = partial( - gpxq_module_optimizer.update_batch, current_layer=self.current_layer) - self.hook_dict[name] = module.register_forward_pre_hook(hook_fn) self.gpxq_layers[name] = gpxq_module_optimizer + def setup_gpxq_hooks(self): + for name, module in self.gpxq_layers.items(): + # Attach hooks for GPxQ + hook_fn = partial(module.update_batch, current_layer=self.current_layer) + self.hook_dict[name] = module.layer.register_forward_pre_hook(hook_fn) + if not self.use_quant_activations: self.return_quant_tensor_state = disable_return_quant_tensor(self.model) self.disable_quant_inference.disable_act_quantization( @@ -160,7 +167,7 @@ def __enter__(self): self.disable_quant_inference.disable_bias_quantization( self.model, is_training=self.model.training) - self.num_layers = len(dict_of_layers) + self.num_layers = len(self.dict_of_layers) return self def __exit__(self, type, value, traceback): diff --git a/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py b/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py index 9d94df12f..f4df00b0a 100644 --- a/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py +++ b/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py @@ -543,7 +543,8 @@ def apply_gpfq( p=1.0, use_gpfa2q=False, accumulator_bit_width=None, - compression_rate=0.0): + compression_rate=0.0, + collect_float_first=True): model.eval() dtype = next(model.parameters()).dtype device = next(model.parameters()).device @@ -554,7 +555,16 @@ def apply_gpfq( act_order=act_order, use_gpfa2q=use_gpfa2q, accumulator_bit_width=accumulator_bit_width, - compression_rate=compression_rate) as gpfq: + compression_rate=compression_rate, + collect_float_first=collect_float_first) as gpfq: + if collect_float_first: + print('Collecting float input first...') + for i, (images, target) in tqdm(enumerate(calib_loader)): + images = images.to(device) + images = images.to(dtype) + gpfq.orig_forward(images) + gpfq.finalize_float_collection() + gpfq_model = gpfq.model for i in tqdm(range(gpfq.num_layers)): for i, (images, target) in enumerate(calib_loader): diff --git a/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py b/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py index 7e2bf6ee5..eed92b6ec 100644 --- a/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py +++ b/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py @@ -245,6 +245,11 @@ def parse_type(v, default_type): type=float, help='Specify compression rate < 1.0 for random projection. Default is 0.0 and does not use RP.' ) +add_bool_arg( + parser, + 'collect-float-first', + default=False, + help='In GPFQ, separate float and quant forward pass for speed up. 
(default: False)') add_bool_arg(parser, 'gptq', default=False, help='GPTQ (default: disabled)') add_bool_arg(parser, 'gpfq', default=False, help='GPFQ (default: disabled)') add_bool_arg(parser, 'gpfa2q', default=False, help='GPFA2Q (default: disabled)') @@ -437,7 +442,8 @@ def main(): quant_model, p=args.gpfq_p, act_order=args.gpxq_act_order, - compression_rate=args.compression_rate) + compression_rate=args.compression_rate, + collect_float_first=args.collect_float_first) if args.gpfa2q: print("Performing GPFA2Q:") @@ -448,7 +454,8 @@ def main(): act_order=args.gpxq_act_order, use_gpfa2q=args.gpfa2q, accumulator_bit_width=args.accumulator_bit_width, - compression_rate=args.compression_rate) + compression_rate=args.compression_rate, + collect_float_first=args.collect_float_first) if args.gptq: print("Performing GPTQ:") From 1e05f987a511bb9faf8726610baf11b5d45c8d70 Mon Sep 17 00:00:00 2001 From: Fabian Grob Date: Mon, 3 Jun 2024 14:50:05 +0100 Subject: [PATCH 2/4] Feat (gpfq): adding example code --- src/brevitas/graph/gpfq.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/brevitas/graph/gpfq.py b/src/brevitas/graph/gpfq.py index fe007bad6..f9ba48654 100644 --- a/src/brevitas/graph/gpfq.py +++ b/src/brevitas/graph/gpfq.py @@ -62,7 +62,12 @@ class gpfq_mode(gpxq_mode): Example: >>> with torch.no_grad(): - >>> with gpfq_mode(model) as gpfq: + >>> with gpfq_mode(model, collect_float_first) as gpfq: + >>> if collect_float_first: + >>> for img, t in calib_loader: + >>> img = img.cuda() + >>> gpfq.orig_forward(img) + >>> gpfq.finalize_float_collection() >>> gpfq_model = gpfq.model >>> for i in tqdm(range(gpfq.num_layers)): >>> for img, t in calib_loader: From bbc025c7450ebe5f56e1eadb3ae5087bf36d1681 Mon Sep 17 00:00:00 2001 From: Fabian Grob Date: Tue, 18 Jun 2024 12:45:31 +0100 Subject: [PATCH 3/4] Feat (GPFQ): offload float input to disc --- src/brevitas/graph/gpfq.py | 52 ++++++++++++++++++++++++++++++++++---- 1 file changed, 47 insertions(+), 5 deletions(-) diff --git a/src/brevitas/graph/gpfq.py b/src/brevitas/graph/gpfq.py index f9ba48654..7eefabbd2 100644 --- a/src/brevitas/graph/gpfq.py +++ b/src/brevitas/graph/gpfq.py @@ -3,8 +3,11 @@ from copy import deepcopy import math +from tempfile import TemporaryDirectory from typing import Callable, List, Optional +from accelerate.utils.offload import offload_state_dict +from accelerate.utils.offload import OffloadedWeightsLoader import numpy as np import torch import torch.nn as nn @@ -122,7 +125,7 @@ def __enter__(self): self.setup_gpxq_layers() if self.collect_float_first: self.float_collection_hooks = dict() - # set up hooks for collecting the float input + # set up hooks for collecting the float input and storing them on disc for name, layer in self.gpxq_layers.items(): # Attach float collecting hook self.float_collection_hooks[name] = layer.layer.register_forward_hook( @@ -143,6 +146,13 @@ def finalize_float_collection(self): for name, hook in self.float_collection_hooks.items(): hook.remove() + # create temp dir + self.tmp_dir = TemporaryDirectory() + + # save all float activations to disc and delete them in the layers + for name, layer in self.gpxq_layers.items(): + layer.offload_float_input(tmp_dir=self.tmp_dir.name) + # Re-enable quantization. 
If activation quantization is disabled, # we also disable bias quantization self.disable_quant_inference.enable_param_quantization(self.model, is_training=False) @@ -155,6 +165,12 @@ def finalize_float_collection(self): # setup the original hooks self.setup_gpxq_hooks() + def __exit__(self, type, value, traceback): + # delete tmp dir + if self.collect_float_first: + self.tmp_dir.cleanup() + return super().__exit__(type, value, traceback) + def catch_stopfwd(self, *args, **kwargs): # Collect quant input try: @@ -202,7 +218,8 @@ def initialize_module_optimizer( len_parallel_layers=len_parallel_layers, create_weight_orig=create_weight_orig, p=self.p, - compression_rate=self.compression_rate) + compression_rate=self.compression_rate, + collect_float_first=self.collect_float_first) else: return GPFA2Q( layer=layer, @@ -212,7 +229,8 @@ def initialize_module_optimizer( create_weight_orig=create_weight_orig, p=self.p, accumulator_bit_width=self.accumulator_bit_width, - compression_rate=self.compression_rate) + compression_rate=self.compression_rate, + collect_float_first=self.collect_float_first) class GPFQ(GPxQ): @@ -221,8 +239,15 @@ class GPFQ(GPxQ): """ def __init__( - self, layer, name, act_order, len_parallel_layers, create_weight_orig, p, - compression_rate) -> None: + self, + layer, + name, + act_order, + len_parallel_layers, + create_weight_orig, + p, + compression_rate, + collect_float_first) -> None: super().__init__(layer, name, act_order, len_parallel_layers, create_weight_orig) @@ -231,6 +256,7 @@ def __init__( self.index_computed = False self.p = p self.compression_rate = compression_rate + self.collect_float_first = collect_float_first def collect_float_input(self, module, args, output): # this is the hook function to offload the output of this layer to disc @@ -296,6 +322,14 @@ def collect_float_input(self, module, args, output): else: self.float_input = torch.cat([self.float_input, inp_processed], dim=1) + def offload_float_input(self, tmp_dir): + # create tmp directory for this layer + self.save_dir = tmp_dir + '/' + self.name + # method expects dict + offload_state_dict(self.save_dir, state_dict={'float_input': self.float_input.detach()}) + # then delete float_input to save memory + del self.float_input + def update_batch(self, module, input, current_layer): if self.disable_pre_forward_hook: return input @@ -382,6 +416,10 @@ def single_layer_update(self): weight = self.layer.weight.data dev = weight.device dtype = weight.dtype + # load float input from disc if needed + if self.collect_float_first: + # load float_input from disc + self.float_input = OffloadedWeightsLoader(save_folder=self.save_dir)['float_input'] if isinstance(self.layer, SUPPORTED_CONV_OP): if isinstance( self.layer, @@ -469,6 +507,10 @@ def single_layer_update(self): weight = self.layer.weight.data dev = weight.device dtype = weight.dtype + # load float input from disc if needed + if self.collect_float_first: + # load float_input from disc + self.float_input = OffloadedWeightsLoader(save_folder=self.save_dir)['float_input'] if isinstance(self.layer, SUPPORTED_CONV_OP): if isinstance(self.layer, (qnn.QuantConvTranspose1d, qnn.QuantConvTranspose2d)): weight = weight.transpose(1, 0) # This performs a view From 9f6ade9265bf692f2db2eaba87e3cef12c45d97b Mon Sep 17 00:00:00 2001 From: Fabian Grob Date: Thu, 20 Jun 2024 11:15:53 +0100 Subject: [PATCH 4/4] Fix (GPFQ): change offloading to use torch --- src/brevitas/graph/gpfq.py | 39 +++++++++++++------ src/brevitas/graph/gptq.py | 7 ++++ src/brevitas/graph/gpxq.py | 6 
--- .../imagenet_classification/ptq/ptq_common.py | 4 +- 4 files changed, 36 insertions(+), 20 deletions(-) diff --git a/src/brevitas/graph/gpfq.py b/src/brevitas/graph/gpfq.py index 7eefabbd2..d04811871 100644 --- a/src/brevitas/graph/gpfq.py +++ b/src/brevitas/graph/gpfq.py @@ -3,17 +3,18 @@ from copy import deepcopy import math +import os from tempfile import TemporaryDirectory from typing import Callable, List, Optional -from accelerate.utils.offload import offload_state_dict -from accelerate.utils.offload import OffloadedWeightsLoader import numpy as np import torch +from torch.fx import GraphModule as TorchGraphModule import torch.nn as nn import unfoldNd from brevitas.function import get_upper_bound_on_l1_norm +from brevitas.fx import GraphModule from brevitas.graph.calibrate import disable_return_quant_tensor from brevitas.graph.calibrate import restore_return_quant_tensor from brevitas.graph.gpxq import GPxQ @@ -66,12 +67,12 @@ class gpfq_mode(gpxq_mode): Example: >>> with torch.no_grad(): >>> with gpfq_mode(model, collect_float_first) as gpfq: + >>> gpfq_model = gpfq.model >>> if collect_float_first: >>> for img, t in calib_loader: >>> img = img.cuda() - >>> gpfq.orig_forward(img) + >>> gpfq_model(img) >>> gpfq.finalize_float_collection() - >>> gpfq_model = gpfq.model >>> for i in tqdm(range(gpfq.num_layers)): >>> for img, t in calib_loader: >>> img = img.cuda() @@ -139,6 +140,12 @@ def __enter__(self): return self else: # if we're not collecting, setup original hooks + # setup catch_stopfwd + self.orig_forward = self.model.forward + if isinstance(self.model, (GraphModule, TorchGraphModule)): + self.model.__class__.forward = self.catch_stopfwd + else: + self.model.forward = self.catch_stopfwd return self.setup_gpxq_hooks() def finalize_float_collection(self): @@ -161,7 +168,12 @@ def finalize_float_collection(self): else: self.disable_quant_inference.disable_bias_quantization(self.model, is_training=False) restore_return_quant_tensor(self.model, self.return_quant_tensor_state) - + # setup catch_stopfwd + self.orig_forward = self.model.forward + if isinstance(self.model, (GraphModule, TorchGraphModule)): + self.model.__class__.forward = self.catch_stopfwd + else: + self.model.forward = self.catch_stopfwd # setup the original hooks self.setup_gpxq_hooks() @@ -325,8 +337,10 @@ def collect_float_input(self, module, args, output): def offload_float_input(self, tmp_dir): # create tmp directory for this layer self.save_dir = tmp_dir + '/' + self.name - # method expects dict - offload_state_dict(self.save_dir, state_dict={'float_input': self.float_input.detach()}) + os.makedirs(self.save_dir, exist_ok=True) + self.float_input_file = self.save_dir + '/float_input.pt' + # offload float input + torch.save(self.float_input, self.float_input_file) # then delete float_input to save memory del self.float_input @@ -418,8 +432,7 @@ def single_layer_update(self): dtype = weight.dtype # load float input from disc if needed if self.collect_float_first: - # load float_input from disc - self.float_input = OffloadedWeightsLoader(save_folder=self.save_dir)['float_input'] + self.float_input = torch.load(self.float_input_file) if isinstance(self.layer, SUPPORTED_CONV_OP): if isinstance( self.layer, @@ -484,7 +497,8 @@ def __init__( create_weight_orig, accumulator_bit_width, p, - compression_rate) -> None: + compression_rate, + collect_float_first) -> None: GPFQ.__init__( self, layer=layer, @@ -493,7 +507,8 @@ def __init__( len_parallel_layers=len_parallel_layers, create_weight_orig=create_weight_orig, 
p=p, - compression_rate=compression_rate) + compression_rate=compression_rate, + collect_float_first=collect_float_first) self.accumulator_bit_width = accumulator_bit_width assert self.accumulator_bit_width is not None @@ -510,7 +525,7 @@ def single_layer_update(self): # load float input from disc if needed if self.collect_float_first: # load float_input from disc - self.float_input = OffloadedWeightsLoader(save_folder=self.save_dir)['float_input'] + self.float_input = torch.load(self.float_input_file) if isinstance(self.layer, SUPPORTED_CONV_OP): if isinstance(self.layer, (qnn.QuantConvTranspose1d, qnn.QuantConvTranspose2d)): weight = weight.transpose(1, 0) # This performs a view diff --git a/src/brevitas/graph/gptq.py b/src/brevitas/graph/gptq.py index 6661be97f..fad7dc0ff 100644 --- a/src/brevitas/graph/gptq.py +++ b/src/brevitas/graph/gptq.py @@ -13,9 +13,11 @@ from torch.linalg import LinAlgError except: LinAlgError = RuntimeError +from torch.fx import GraphModule as TorchGraphModule import unfoldNd from brevitas import torch_version +from brevitas.fx import GraphModule from brevitas.graph.gpxq import GPxQ from brevitas.graph.gpxq import gpxq_mode from brevitas.graph.gpxq import StopFwdException @@ -77,6 +79,11 @@ def __init__( self.num_blocks = num_blocks def __enter__(self): + self.orig_forward = self.model.forward + if isinstance(self.model, (GraphModule, TorchGraphModule)): + self.model.__class__.forward = self.catch_stopfwd + else: + self.model.forward = self.catch_stopfwd self.setup_gpxq_layers() return self.setup_gpxq_hooks() diff --git a/src/brevitas/graph/gpxq.py b/src/brevitas/graph/gpxq.py index dcd193c08..31407d7c3 100644 --- a/src/brevitas/graph/gpxq.py +++ b/src/brevitas/graph/gpxq.py @@ -99,12 +99,6 @@ def __init__( self.group_of_parallel_layers = group_of_parallel_layers self.return_forward_output = return_forward_output - self.orig_forward = self.model.forward - if isinstance(self.model, (GraphModule, TorchGraphModule)): - self.model.__class__.forward = self.catch_stopfwd - else: - self.model.forward = self.catch_stopfwd - def _is_module_supported(self, module): if isinstance(module, SUPPORTED_CONV_OP): return True diff --git a/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py b/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py index f4df00b0a..afa44e0ef 100644 --- a/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py +++ b/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py @@ -557,15 +557,15 @@ def apply_gpfq( accumulator_bit_width=accumulator_bit_width, compression_rate=compression_rate, collect_float_first=collect_float_first) as gpfq: + gpfq_model = gpfq.model if collect_float_first: print('Collecting float input first...') for i, (images, target) in tqdm(enumerate(calib_loader)): images = images.to(device) images = images.to(dtype) - gpfq.orig_forward(images) + gpfq_model(images) gpfq.finalize_float_collection() - gpfq_model = gpfq.model for i in tqdm(range(gpfq.num_layers)): for i, (images, target) in enumerate(calib_loader): images = images.to(device)
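
Usage sketch (illustrative, not part of the patches): the series splits GPFQ calibration into one float-collection pass and the usual per-layer quant passes, offloading the collected float inputs to a temporary directory with torch.save/torch.load. The sketch below mirrors the updated gpfq_mode docstring and apply_gpfq in the diff; the helper name and the device/dtype handling are assumptions, while gpfq_mode(..., collect_float_first=True), finalize_float_collection(), gpfq.num_layers and gpfq.update() come from the patches themselves.

import torch
from tqdm import tqdm

from brevitas.graph.gpfq import gpfq_mode


def apply_gpfq_collect_float_first(model, calib_loader, device, dtype):
    # Assumed setup: `model` is a Brevitas quantized model, `calib_loader`
    # yields (images, target) batches compatible with `device`/`dtype`.
    model.eval()
    with torch.no_grad():
        with gpfq_mode(model, collect_float_first=True) as gpfq:
            gpfq_model = gpfq.model
            # Stage 1: a single pass with quantization disabled; the forward hooks
            # registered in __enter__ collect each layer's float input, and
            # finalize_float_collection() offloads them to a TemporaryDirectory
            # on disk before re-enabling quantization and the GPxQ hooks.
            for images, _ in tqdm(calib_loader):
                images = images.to(device=device, dtype=dtype)
                gpfq_model(images)
            gpfq.finalize_float_collection()
            # Stage 2: the usual per-layer GPFQ loop now runs only the quant
            # forward pass; each layer reloads its float input from disk inside
            # single_layer_update().
            for _ in tqdm(range(gpfq.num_layers)):
                for images, _ in calib_loader:
                    images = images.to(device=device, dtype=dtype)
                    gpfq_model(images)
                gpfq.update()
    return model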