Feat (gpfq): separate float and quant forward pass for speedup
fabianandresgrob committed Jun 3, 2024
1 parent 02f5b6b commit bece0c6
Showing 5 changed files with 167 additions and 34 deletions.
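The commit splits GPFQ calibration into two phases when collect_float_first=True: a single float forward pass over the calibration data, during which per-layer forward hooks cache the float inputs while quantization is disabled, followed by the usual per-layer quantized passes. A minimal sketch of how the new flow is driven (based on the apply_gpfq changes below; calib_loader and device are placeholders, and the exact Brevitas loop may differ):

    with gpfq_mode(model, collect_float_first=True) as gpfq:
        # Phase 1: float pass only; hooks registered in __enter__ cache each layer's float input
        for images, _ in calib_loader:
            gpfq.orig_forward(images.to(device))
        gpfq.finalize_float_collection()  # remove hooks, re-enable quantization

        # Phase 2: quantized passes; catch_stopfwd no longer repeats the float forward
        gpfq_model = gpfq.model
        for _ in range(gpfq.num_layers):
            for images, _ in calib_loader:
                gpfq_model(images.to(device))
            gpfq.update()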
145 changes: 125 additions & 20 deletions src/brevitas/graph/gpfq.py
@@ -3,13 +3,10 @@

from copy import deepcopy
import math
from math import pi
from typing import Callable, List, Optional

import numpy as np
import torch
from torch.fft import fft
from torch.fft import fftn
import torch.nn as nn
import unfoldNd

@@ -87,7 +84,8 @@ def __init__(
use_gpfa2q: bool = False,
accumulator_bit_width: Optional[int] = None,
a2q_layer_filter_fnc: Optional[Callable[[nn.Module], bool]] = lambda x: True,
compression_rate: Optional[float] = 0.0) -> None:
compression_rate: Optional[float] = 0.0,
collect_float_first: bool = False) -> None:
if not inplace:
model = deepcopy(model)
super().__init__(
@@ -110,23 +108,35 @@ def __init__(
self.compression_rate = compression_rate
if self.compression_rate < 0.0 or self.compression_rate > 1.0:
raise ValueError('Compression rate for random projection must be between 0 and 1.')

# Speed up GPFQ by collecting the float inputs in a separate pass so they do not have to be recomputed later
self.collect_float_first = collect_float_first

def __enter__(self):
# initialize gpxq layers
self.setup_gpxq_layers()
if self.collect_float_first:
self.float_collection_hooks = dict()
# set up hooks for collecting the float input
for name, layer in self.gpxq_layers.items():
# Attach float collecting hook
self.float_collection_hooks[name] = layer.layer.register_forward_hook(
layer.collect_float_input)

# Disable quantization
self.return_quant_tensor_state = disable_return_quant_tensor(self.model)
self.disable_quant_inference.disable_param_quantization(self.model, is_training=False)
self.disable_quant_inference.disable_act_quantization(self.model, is_training=False)

return self
else:
# if we are not collecting the float input first, set up the original hooks
return self.setup_gpxq_hooks()

def catch_stopfwd(self, *args, **kwargs):
# Collect quant input
try:
self.orig_forward(*args, **kwargs)
except StopFwdException:
pass

# Disable quantization
self.return_quant_tensor_state = disable_return_quant_tensor(self.model)
self.disable_quant_inference.disable_param_quantization(self.model, is_training=False)
self.disable_quant_inference.disable_act_quantization(self.model, is_training=False)
# Collect float input
try:
self.orig_forward(*args, **kwargs)
except StopFwdException:
pass
def finalize_float_collection(self):
# remove the hooks we attached during the float collection
for name, hook in self.float_collection_hooks.items():
hook.remove()

# Re-enable quantization. If activation quantization is disabled,
# we also disable bias quantization
@@ -137,6 +147,37 @@ def catch_stopfwd(self, *args, **kwargs):
self.disable_quant_inference.disable_bias_quantization(self.model, is_training=False)
restore_return_quant_tensor(self.model, self.return_quant_tensor_state)

# set up the original hooks
self.setup_gpxq_hooks()

def catch_stopfwd(self, *args, **kwargs):
# Collect quant input
try:
self.orig_forward(*args, **kwargs)
except StopFwdException:
pass

if not self.collect_float_first:
# Disable quantization
self.return_quant_tensor_state = disable_return_quant_tensor(self.model)
self.disable_quant_inference.disable_param_quantization(self.model, is_training=False)
self.disable_quant_inference.disable_act_quantization(self.model, is_training=False)
# Collect float input
try:
self.orig_forward(*args, **kwargs)
except StopFwdException:
pass

# Re-enable quantization. If activation quantization is disabled,
# we also disable bias quantization
self.disable_quant_inference.enable_param_quantization(self.model, is_training=False)
if self.use_quant_activations:
self.disable_quant_inference.enable_act_quantization(self.model, is_training=False)
else:
self.disable_quant_inference.disable_bias_quantization(
self.model, is_training=False)
restore_return_quant_tensor(self.model, self.return_quant_tensor_state)

if self.return_forward_output:
# If we want to return the output of the network, we need to disable all hooks
for name, gpxq_class in self.gpxq_layers.items():
@@ -186,6 +227,70 @@ def __init__(
self.p = p
self.compression_rate = compression_rate

def collect_float_input(self, module, args, output):
# forward hook that collects this layer's float input and offloads it to CPU
inp = self.process_input(args)
batch_size = inp.shape[0]

# Preprocess the input to compute the Hessian
if isinstance(self.layer, qnn.QuantLinear):
if len(inp.shape) > 2:
inp = inp.reshape((-1, sum(inp.shape[2:])))
# For QuantLinear layer, groups will be 1
inp_processed = inp.unsqueeze(0)

if isinstance(self.layer, SUPPORTED_CONV_OP):
# Pick the correct unfoldNd class
if isinstance(
self.layer,
(qnn.QuantConvTranspose1d, qnn.QuantConvTranspose2d, qnn.QuantConvTranspose3d)):
unfold_impl = unfoldNd.UnfoldTransposeNd
else:
unfold_impl = unfoldNd.UnfoldNd

unfold = unfold_impl(
self.layer.kernel_size,
dilation=self.layer.dilation,
padding=self.layer.padding,
stride=self.layer.kernel_size)

# Split input based on how many groups in convolution
inp_by_group = torch.chunk(inp, self.groups, 1)
inp_processed = []
# Preprocess input by group
for i, inp in enumerate(inp_by_group):

inp = unfold(inp)

batch_size, num_blocks = inp.shape[0], inp.shape[-1]
inp = torch.transpose(inp, 1, 2) # shape (B, L, C*kernel_size[0]*kernel_size[1])
inp = inp.reshape(-1, inp.size(-1)) # shape (B*L, C*kernel_size[0]*kernel_size[1])

if not self.index_computed:
self.index_computed = True
self.rand_indices = np.concatenate([
np.random.choice(
np.arange(num_blocks * i, num_blocks * (i + 1)),
size=int(
self.p * num_blocks + 1 if self.p != 1 else self.p * num_blocks))
for i in range(batch_size)]) # need to define self.p (probability)

indexes = self.rand_indices
if np.max(self.rand_indices) > inp.shape[0]:
indexes = self.rand_indices < inp.shape[0]
indexes = self.rand_indices[indexes]

inp = inp[indexes]
inp_processed.append(inp)
inp_processed = torch.stack(inp_processed)

inp_processed = inp_processed.cpu()

if self.float_input is None:
self.float_input = inp_processed
else:
self.float_input = torch.cat([self.float_input, inp_processed], dim=1)

def update_batch(self, module, input, current_layer):
if self.disable_pre_forward_hook:
return input
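For reference, the float-collection hook uses PyTorch's standard forward-hook mechanism: register_forward_hook calls back with (module, inputs, output), and the hook caches the layer's preprocessed input on CPU until the handle is removed by finalize_float_collection. A standalone analogue with hypothetical names, not Brevitas code:

    import torch
    import torch.nn as nn

    cached = {}  # per-layer store for collected float inputs

    def make_collector(name):
        def hook(module, inputs, output):
            # detach the layer input and offload it to CPU, appending across batches
            inp = inputs[0].detach().reshape(-1, inputs[0].shape[-1]).cpu()
            cached[name] = inp if name not in cached else torch.cat([cached[name], inp], dim=0)
        return hook

    layer = nn.Linear(16, 8)
    handle = layer.register_forward_hook(make_collector("fc"))
    layer(torch.randn(4, 16))  # hook fires during the float pass
    handle.remove()            # analogous to finalize_float_collection()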
4 changes: 4 additions & 0 deletions src/brevitas/graph/gptq.py
@@ -76,6 +76,10 @@ def __init__(
# How many subblock to use during GPTQ for each layer
self.num_blocks = num_blocks

def __enter__(self):
self.setup_gpxq_layers()
return self.setup_gpxq_hooks()

def catch_stopfwd(self, *args, **kwargs):
try:
self.orig_forward(*args, **kwargs)
27 changes: 17 additions & 10 deletions src/brevitas/graph/gpxq.py
@@ -113,54 +113,61 @@ def _is_module_supported(self, module):
else:
return False

@abstractmethod
def __enter__(self):
pass

def setup_gpxq_layers(self):
# The user can specify on which layers to apply gptq in parallel.
# All the others will be executed sequentially
dict_of_layers = {
self.dict_of_layers = {
name: [(name, module)] for name,
module in self.model.named_modules() if self._is_module_supported(module)}
if self.group_of_parallel_layers is not None:
for parallel_layers in self.group_of_parallel_layers:
for name in parallel_layers:
if name not in dict_of_layers:
if name not in self.dict_of_layers:
raise ValueError(
"The layer {} is not present in the model or it is not supported for GPTQ"
.format(name))
del dict_of_layers[name]
del self.dict_of_layers[name]
names = '_'.join(parallel_layers)
dict_of_layers[names] = [
self.dict_of_layers[names] = [
(name, attrgetter(name)(self.model)) for name in parallel_layers]

# Print warning if hooks are attached to any module, since the normal forward flow of the
# network is highly disrupted during GPxQ
for _, parallel_layers in dict_of_layers.items():
for _, parallel_layers in self.dict_of_layers.items():
for name, module in parallel_layers:
if len(module._forward_hooks) > 0 or len(module._forward_pre_hooks):
warnings.warn(
f'Hooks detected during setup for GPxQ. '
f'Behaviour might deviate from what is expected.')

# Attach hooks for GPTQ
# initialize GPxQ
if self._is_module_supported(module):
gpxq_module_optimizer = self.initialize_module_optimizer(
module,
name,
act_order=self.act_order,
len_parallel_layers=len(parallel_layers),
create_weight_orig=self.create_weight_orig)
hook_fn = partial(
gpxq_module_optimizer.update_batch, current_layer=self.current_layer)
self.hook_dict[name] = module.register_forward_pre_hook(hook_fn)
self.gpxq_layers[name] = gpxq_module_optimizer

def setup_gpxq_hooks(self):
for name, module in self.gpxq_layers.items():
# Attach hooks for GPxQ
hook_fn = partial(module.update_batch, current_layer=self.current_layer)
self.hook_dict[name] = module.layer.register_forward_pre_hook(hook_fn)

if not self.use_quant_activations:
self.return_quant_tensor_state = disable_return_quant_tensor(self.model)
self.disable_quant_inference.disable_act_quantization(
self.model, is_training=self.model.training)
self.disable_quant_inference.disable_bias_quantization(
self.model, is_training=self.model.training)

self.num_layers = len(dict_of_layers)
self.num_layers = len(self.dict_of_layers)
return self

def __exit__(self, type, value, traceback):
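The gpxq.py change splits the previous single setup step into setup_gpxq_layers (build the per-layer GPxQ optimizers and store dict_of_layers) and setup_gpxq_hooks (attach the update_batch pre-forward hooks), so a subclass can do extra work between the two. A simplified sketch of the resulting __enter__ protocol, mirroring the gptq.py change above (illustrative subclass, import path assumed):

    from brevitas.graph.gpxq import gpxq_mode

    class example_mode(gpxq_mode):  # not part of the commit, for illustration only
        def __enter__(self):
            self.setup_gpxq_layers()        # create GPxQ optimizers for supported layers
            # a subclass may register extra hooks here, as gpfq_mode does for float collection
            return self.setup_gpxq_hooks()  # attach update_batch pre-forward hooks and return self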
14 changes: 12 additions & 2 deletions src/brevitas_examples/imagenet_classification/ptq/ptq_common.py
@@ -543,7 +543,8 @@ def apply_gpfq(
p=1.0,
use_gpfa2q=False,
accumulator_bit_width=None,
compression_rate=0.0):
compression_rate=0.0,
collect_float_first=True):
model.eval()
dtype = next(model.parameters()).dtype
device = next(model.parameters()).device
Expand All @@ -554,7 +555,16 @@ def apply_gpfq(
act_order=act_order,
use_gpfa2q=use_gpfa2q,
accumulator_bit_width=accumulator_bit_width,
compression_rate=compression_rate) as gpfq:
compression_rate=compression_rate,
collect_float_first=collect_float_first) as gpfq:
if collect_float_first:
print('Collecting float input first...')
for i, (images, target) in tqdm(enumerate(calib_loader)):
images = images.to(device)
images = images.to(dtype)
gpfq.orig_forward(images)
gpfq.finalize_float_collection()

gpfq_model = gpfq.model
for i in tqdm(range(gpfq.num_layers)):
for i, (images, target) in enumerate(calib_loader):
@@ -245,6 +245,11 @@ def parse_type(v, default_type):
type=float,
help='Specify compression rate < 1.0 for random projection. Default is 0.0 and does not use RP.'
)
add_bool_arg(
parser,
'collect-float-first',
default=False,
help='In GPFQ, separate the float and quant forward passes for a speedup. (default: False)')
add_bool_arg(parser, 'gptq', default=False, help='GPTQ (default: disabled)')
add_bool_arg(parser, 'gpfq', default=False, help='GPFQ (default: disabled)')
add_bool_arg(parser, 'gpfa2q', default=False, help='GPFA2Q (default: disabled)')
@@ -437,7 +442,8 @@ def main():
quant_model,
p=args.gpfq_p,
act_order=args.gpxq_act_order,
compression_rate=args.compression_rate)
compression_rate=args.compression_rate,
collect_float_first=args.collect_float_first)

if args.gpfa2q:
print("Performing GPFA2Q:")
@@ -448,7 +454,8 @@ def main():
act_order=args.gpxq_act_order,
use_gpfa2q=args.gpfa2q,
accumulator_bit_width=args.accumulator_bit_width,
compression_rate=args.compression_rate)
compression_rate=args.compression_rate,
collect_float_first=args.collect_float_first)

if args.gptq:
print("Performing GPTQ:")
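Since add_bool_arg registers the option as a boolean flag, the two-pass behaviour can then be toggled from the example's command line via --collect-float-first (disabled by default), and main() forwards args.collect_float_first to apply_gpfq for both the GPFQ and GPFA2Q runs.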