Feat (gpfq): separate float and quant forward pass for speedup #955

Closed
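For context, here is a minimal sketch of the two-pass calibration flow this change enables, adapted from the updated gpfq_mode docstring in the diff below. The toy model, the random calibration data, and running on CPU are illustrative assumptions, not part of the PR.

import torch
from torch.utils.data import DataLoader, TensorDataset

import brevitas.nn as qnn
from brevitas.graph.gpfq import gpfq_mode

# Toy quantized model and random calibration data, only to exercise the flow.
model = torch.nn.Sequential(
    qnn.QuantConv2d(3, 8, 3), torch.nn.ReLU(), qnn.QuantConv2d(8, 8, 3))
calib_loader = DataLoader(
    TensorDataset(torch.randn(8, 3, 32, 32), torch.zeros(8)), batch_size=4)

with torch.no_grad():
    with gpfq_mode(model, collect_float_first=True) as gpfq:
        gpfq_model = gpfq.model
        # Pass 1: float forward passes; hooks collect the float input of each layer.
        for img, _ in calib_loader:
            gpfq_model(img)
        # Offload the float inputs to a temporary directory and re-enable quantization.
        gpfq.finalize_float_collection()
        # Pass 2: quantized forward passes, one sweep over the data per layer.
        for _ in range(gpfq.num_layers):
            for img, _ in calib_loader:
                gpfq_model(img)
            gpfq.update()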
219 changes: 193 additions & 26 deletions src/brevitas/graph/gpfq.py
@@ -3,17 +3,18 @@

from copy import deepcopy
import math
from math import pi
import os
from tempfile import TemporaryDirectory
from typing import Callable, List, Optional

import numpy as np
import torch
from torch.fft import fft
from torch.fft import fftn
from torch.fx import GraphModule as TorchGraphModule
import torch.nn as nn
import unfoldNd

from brevitas.function import get_upper_bound_on_l1_norm
from brevitas.fx import GraphModule
from brevitas.graph.calibrate import disable_return_quant_tensor
from brevitas.graph.calibrate import restore_return_quant_tensor
from brevitas.graph.gpxq import GPxQ
@@ -65,8 +66,13 @@ class gpfq_mode(gpxq_mode):

Example:
>>> with torch.no_grad():
>>> with gpfq_mode(model) as gpfq:
>>> with gpfq_mode(model, collect_float_first=collect_float_first) as gpfq:
>>> gpfq_model = gpfq.model
>>> if collect_float_first:
>>> for img, t in calib_loader:
>>> img = img.cuda()
>>> gpfq_model(img)
>>> gpfq.finalize_float_collection()
>>> for i in tqdm(range(gpfq.num_layers)):
>>> for img, t in calib_loader:
>>> img = img.cuda()
@@ -87,7 +93,8 @@ def __init__(
use_gpfa2q: bool = False,
accumulator_bit_width: Optional[int] = None,
a2q_layer_filter_fnc: Optional[Callable[[nn.Module], bool]] = lambda x: True,
compression_rate: Optional[float] = 0.0) -> None:
compression_rate: Optional[float] = 0.0,
collect_float_first: bool = False) -> None:
if not inplace:
model = deepcopy(model)
super().__init__(
@@ -111,22 +118,47 @@ def __init__(
if self.compression_rate < 0.0 or self.compression_rate > 1.0:
raise ValueError('Compression rate for random projection must be between 0 and 1.')

def catch_stopfwd(self, *args, **kwargs):
# Collect quant input
try:
self.orig_forward(*args, **kwargs)
except StopFwdException:
pass
# speed things up by collecting the float input first so we don't need to do it again later
self.collect_float_first = collect_float_first

def __enter__(self):
# initialize gpxq layers
self.setup_gpxq_layers()
if self.collect_float_first:
self.float_collection_hooks = dict()
# set up hooks for collecting the float inputs and storing them on disk
for name, layer in self.gpxq_layers.items():
# Attach float collecting hook
self.float_collection_hooks[name] = layer.layer.register_forward_hook(
layer.collect_float_input)

# Disable quantization
self.return_quant_tensor_state = disable_return_quant_tensor(self.model)
self.disable_quant_inference.disable_param_quantization(self.model, is_training=False)
self.disable_quant_inference.disable_act_quantization(self.model, is_training=False)

return self
else:
# if we're not collecting the float inputs first, set up the original hooks
# set up catch_stopfwd
self.orig_forward = self.model.forward
if isinstance(self.model, (GraphModule, TorchGraphModule)):
self.model.__class__.forward = self.catch_stopfwd
else:
self.model.forward = self.catch_stopfwd
return self.setup_gpxq_hooks()

# Disable quantization
self.return_quant_tensor_state = disable_return_quant_tensor(self.model)
self.disable_quant_inference.disable_param_quantization(self.model, is_training=False)
self.disable_quant_inference.disable_act_quantization(self.model, is_training=False)
# Collect float input
try:
self.orig_forward(*args, **kwargs)
except StopFwdException:
pass
def finalize_float_collection(self):
# remove the hooks we attached during the float collection
for name, hook in self.float_collection_hooks.items():
hook.remove()

# create temp dir
self.tmp_dir = TemporaryDirectory()

# save all float activations to disk and delete them from the layers
for name, layer in self.gpxq_layers.items():
layer.offload_float_input(tmp_dir=self.tmp_dir.name)

# Re-enable quantization. If activation quantization is disabled,
# we also disable bias quantization
@@ -136,6 +168,48 @@ def catch_stopfwd(self, *args, **kwargs):
else:
self.disable_quant_inference.disable_bias_quantization(self.model, is_training=False)
restore_return_quant_tensor(self.model, self.return_quant_tensor_state)
# set up catch_stopfwd
self.orig_forward = self.model.forward
if isinstance(self.model, (GraphModule, TorchGraphModule)):
self.model.__class__.forward = self.catch_stopfwd
else:
self.model.forward = self.catch_stopfwd
# set up the original hooks
self.setup_gpxq_hooks()

def __exit__(self, type, value, traceback):
# delete tmp dir
if self.collect_float_first:
self.tmp_dir.cleanup()
return super().__exit__(type, value, traceback)

def catch_stopfwd(self, *args, **kwargs):
# Collect quant input
try:
self.orig_forward(*args, **kwargs)
except StopFwdException:
pass

if not self.collect_float_first:
# Disable quantization
self.return_quant_tensor_state = disable_return_quant_tensor(self.model)
self.disable_quant_inference.disable_param_quantization(self.model, is_training=False)
self.disable_quant_inference.disable_act_quantization(self.model, is_training=False)
# Collect float input
try:
self.orig_forward(*args, **kwargs)
except StopFwdException:
pass

# Re-enable quantization. If activation quantization is disabled,
# we also disable bias quantization
self.disable_quant_inference.enable_param_quantization(self.model, is_training=False)
if self.use_quant_activations:
self.disable_quant_inference.enable_act_quantization(self.model, is_training=False)
else:
self.disable_quant_inference.disable_bias_quantization(
self.model, is_training=False)
restore_return_quant_tensor(self.model, self.return_quant_tensor_state)

if self.return_forward_output:
# If we want to return the output of the network, we need to disable all hooks
@@ -156,7 +230,8 @@ def initialize_module_optimizer(
len_parallel_layers=len_parallel_layers,
create_weight_orig=create_weight_orig,
p=self.p,
compression_rate=self.compression_rate)
compression_rate=self.compression_rate,
collect_float_first=self.collect_float_first)
else:
return GPFA2Q(
layer=layer,
@@ -166,7 +241,8 @@ def __init__(
create_weight_orig=create_weight_orig,
p=self.p,
accumulator_bit_width=self.accumulator_bit_width,
compression_rate=self.compression_rate)
compression_rate=self.compression_rate,
collect_float_first=self.collect_float_first)


class GPFQ(GPxQ):
@@ -175,8 +251,15 @@ class GPFQ(GPxQ):
"""

def __init__(
self, layer, name, act_order, len_parallel_layers, create_weight_orig, p,
compression_rate) -> None:
self,
layer,
name,
act_order,
len_parallel_layers,
create_weight_orig,
p,
compression_rate,
collect_float_first) -> None:

super().__init__(layer, name, act_order, len_parallel_layers, create_weight_orig)

@@ -185,6 +268,81 @@ def __init__(
self.index_computed = False
self.p = p
self.compression_rate = compression_rate
self.collect_float_first = collect_float_first

def collect_float_input(self, module, args, output):
# forward hook that collects the float input of this layer so it can later be offloaded to disk
inp = self.process_input(args)
batch_size = inp.shape[0]

# Preprocess the input to compute the Hessian
if isinstance(self.layer, qnn.QuantLinear):
if len(inp.shape) > 2:
inp = inp.reshape((-1, sum(inp.shape[2:])))
# For QuantLinear layer, groups will be 1
inp_processed = inp.unsqueeze(0)

if isinstance(self.layer, SUPPORTED_CONV_OP):
# Pick the correct unfoldNd class
if isinstance(
self.layer,
(qnn.QuantConvTranspose1d, qnn.QuantConvTranspose2d, qnn.QuantConvTranspose3d)):
unfold_impl = unfoldNd.UnfoldTransposeNd
else:
unfold_impl = unfoldNd.UnfoldNd

unfold = unfold_impl(
self.layer.kernel_size,
dilation=self.layer.dilation,
padding=self.layer.padding,
stride=self.layer.kernel_size)

# Split input based on how many groups in convolution
inp_by_group = torch.chunk(inp, self.groups, 1)
inp_processed = []
# Preprocess input by group
for i, inp in enumerate(inp_by_group):

inp = unfold(inp)

batch_size, num_blocks = inp.shape[0], inp.shape[-1]
inp = torch.transpose(inp, 1, 2) # shape (B, L, C*kernel_size[0]*kernel_size[1])
inp = inp.reshape(-1, inp.size(-1)) # shape (B*L, C*kernel_size[0]*kernel_size[1])

if not self.index_computed:
self.index_computed = True
self.rand_indices = np.concatenate([
np.random.choice(
np.arange(num_blocks * i, num_blocks * (i + 1)),
size=int(
self.p * num_blocks + 1 if self.p != 1 else self.p * num_blocks))
for i in range(batch_size)]) # need to define self.p (probability)

indexes = self.rand_indices
if np.max(self.rand_indices) > inp.shape[0]:
indexes = self.rand_indices < inp.shape[0]
indexes = self.rand_indices[indexes]

inp = inp[indexes]
inp_processed.append(inp)
inp_processed = torch.stack(inp_processed)

inp_processed = inp_processed.cpu()

if self.float_input is None:
self.float_input = inp_processed
else:
self.float_input = torch.cat([self.float_input, inp_processed], dim=1)

def offload_float_input(self, tmp_dir):
# create tmp directory for this layer
self.save_dir = tmp_dir + '/' + self.name
os.makedirs(self.save_dir, exist_ok=True)
self.float_input_file = self.save_dir + '/float_input.pt'
# offload float input
torch.save(self.float_input, self.float_input_file)
# then delete float_input to save memory
del self.float_input

def update_batch(self, module, input, current_layer):
if self.disable_pre_forward_hook:
Expand Down Expand Up @@ -272,6 +430,9 @@ def single_layer_update(self):
weight = self.layer.weight.data
dev = weight.device
dtype = weight.dtype
# load float input from disk if needed
if self.collect_float_first:
self.float_input = torch.load(self.float_input_file)
if isinstance(self.layer, SUPPORTED_CONV_OP):
if isinstance(
self.layer,
Expand Down Expand Up @@ -336,7 +497,8 @@ def __init__(
create_weight_orig,
accumulator_bit_width,
p,
compression_rate) -> None:
compression_rate,
collect_float_first) -> None:
GPFQ.__init__(
self,
layer=layer,
@@ -345,7 +507,8 @@ def __init__(
len_parallel_layers=len_parallel_layers,
create_weight_orig=create_weight_orig,
p=p,
compression_rate=compression_rate)
compression_rate=compression_rate,
collect_float_first=collect_float_first)
self.accumulator_bit_width = accumulator_bit_width
assert self.accumulator_bit_width is not None

@@ -359,6 +522,10 @@ def single_layer_update(self):
weight = self.layer.weight.data
dev = weight.device
dtype = weight.dtype
# load float input from disk if needed
if self.collect_float_first:
# load float_input from disk
self.float_input = torch.load(self.float_input_file)
if isinstance(self.layer, SUPPORTED_CONV_OP):
if isinstance(self.layer, (qnn.QuantConvTranspose1d, qnn.QuantConvTranspose2d)):
weight = weight.transpose(1, 0) # This performs a view
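The core mechanism added to gpfq.py above is a forward hook that caches each layer's float input during the first pass and then spills it to a temporary directory, so single_layer_update can reload it from disk instead of recomputing it. A simplified, standalone sketch of that collect-then-offload pattern follows; the class and file names here are illustrative, not the Brevitas internals.

import os
from tempfile import TemporaryDirectory

import torch
import torch.nn as nn

class FloatInputCollector:
    # Illustrative only: cache a layer's float inputs, then spill them to disk.
    def __init__(self, name):
        self.name = name
        self.float_input = None

    def collect(self, module, args, output):
        inp = args[0].detach().cpu()
        # Concatenate across calibration batches, mirroring collect_float_input above.
        self.float_input = inp if self.float_input is None else torch.cat(
            [self.float_input, inp], dim=0)

    def offload(self, tmp_dir):
        save_dir = os.path.join(tmp_dir, self.name)
        os.makedirs(save_dir, exist_ok=True)
        self.float_input_file = os.path.join(save_dir, 'float_input.pt')
        torch.save(self.float_input, self.float_input_file)
        del self.float_input  # free memory; reload later with torch.load

layer = nn.Linear(16, 16)
collector = FloatInputCollector('fc')
handle = layer.register_forward_hook(collector.collect)
for _ in range(4):  # stand-in for the float calibration passes
    layer(torch.randn(8, 16))
handle.remove()
with TemporaryDirectory() as tmp_dir:
    collector.offload(tmp_dir)
    restored = torch.load(collector.float_input_file)  # what single_layer_update does
    print(restored.shape)  # torch.Size([32, 16])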
11 changes: 11 additions & 0 deletions src/brevitas/graph/gptq.py
@@ -13,9 +13,11 @@
from torch.linalg import LinAlgError
except:
LinAlgError = RuntimeError
from torch.fx import GraphModule as TorchGraphModule
import unfoldNd

from brevitas import torch_version
from brevitas.fx import GraphModule
from brevitas.graph.gpxq import GPxQ
from brevitas.graph.gpxq import gpxq_mode
from brevitas.graph.gpxq import StopFwdException
@@ -76,6 +78,15 @@ def __init__(
# How many subblock to use during GPTQ for each layer
self.num_blocks = num_blocks

def __enter__(self):
self.orig_forward = self.model.forward
if isinstance(self.model, (GraphModule, TorchGraphModule)):
self.model.__class__.forward = self.catch_stopfwd
else:
self.model.forward = self.catch_stopfwd
self.setup_gpxq_layers()
return self.setup_gpxq_hooks()

def catch_stopfwd(self, *args, **kwargs):
try:
self.orig_forward(*args, **kwargs)
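Finally, both gpfq_mode and gptq_mode rely on the stop-forward trick visible in catch_stopfwd above: the model's forward is swapped for a wrapper that runs the original forward and swallows a sentinel StopFwdException raised by the layer hooks once they have captured what they need. A minimal sketch of that pattern, with illustrative names rather than the actual GPxQ hook logic:

import torch
import torch.nn as nn

class StopFwdException(Exception):
    # Sentinel used to abort the forward pass once the input is captured.
    pass

class InputCatcher:
    def __init__(self):
        self.cached = []

    def __call__(self, module, args):
        self.cached.append(args[0].detach())
        raise StopFwdException  # no need to run the rest of the network

model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 4))
catcher = InputCatcher()
handle = model[0].register_forward_pre_hook(catcher)

orig_forward = model.forward

def catch_stopfwd(*args, **kwargs):
    # Mirrors catch_stopfwd above: run the forward, ignore the sentinel.
    try:
        orig_forward(*args, **kwargs)
    except StopFwdException:
        pass

model.forward = catch_stopfwd
model(torch.randn(2, 8))  # hook captures the input, then the pass is cut short
handle.remove()
print(catcher.cached[0].shape)  # torch.Size([2, 8])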