Torch compile support
Giuseppe5 committed Aug 15, 2024
1 parent 337196a commit 5b138d3
Showing 12 changed files with 95 additions and 61 deletions.
10 changes: 5 additions & 5 deletions src/brevitas/nn/mixin/base.py
@@ -67,8 +67,8 @@ def __init__(self, return_quant_tensor: bool):
     def channelwise_separable(self) -> bool:
         pass
 
-    def _set_global_is_quant_layer(self, value):
-        config._IS_INSIDE_QUANT_LAYER = value
+    # def _set_global_is_quant_layer(self, value):
+    #     config._IS_INSIDE_QUANT_LAYER = value
 
     def get_quant_tensor_class(self, inp: Union[Tensor, QuantTensor]):
         quant_tensor_classes = [IntQuantTensor, FloatQuantTensor]
@@ -78,23 +78,23 @@ def get_quant_tensor_class(self, inp: Union[Tensor, QuantTensor]):
         return None
 
     def unpack_input(self, inp: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTensor]:
-        self._set_global_is_quant_layer(True)
+        # self._set_global_is_quant_layer(True)
         # Hack to recognize a QuantTensor that has decayed to a tuple
         # when used as input to tracing (e.g. during ONNX export)
         if (torch._C._get_tracing_state() is not None and isinstance(inp, tuple) and
                 all([isinstance(t, Tensor) for t in inp])):
             qt_class = self.get_quant_tensor_class(inp)
             if qt_class is not None:
                 inp = qt_class(*inp)
-        if not torch._C._get_tracing_state():
+        if not torch._C._get_tracing_state() and not torch.compiler.is_compiling():
             if isinstance(inp, QuantTensor):
                 inp = inp.set(value=inp.value.rename(None))
             else:
                 inp = inp.rename(None)
         return inp
 
     def pack_output(self, quant_output: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTensor]:
-        self._set_global_is_quant_layer(False)
+        # self._set_global_is_quant_layer(False)
         if self.return_quant_tensor:
             assert isinstance(quant_output, QuantTensor), 'QuantLayer is not correctly configured, check if warnings were raised'
         return quant_output
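Note: the extended guard in unpack_input skips named-tensor calls such as rename(None) not only under torch.jit tracing but also under torch.compile, since TorchDynamo cannot trace named-tensor ops. A minimal sketch of the pattern (the standalone helper is illustrative, not Brevitas API):

    import torch

    def strip_tensor_names(t: torch.Tensor) -> torch.Tensor:
        # rename(None) is a named-tensor op that graph-breaks under Dynamo
        # and is invalid under tracing, so it runs only in plain eager mode.
        if not torch._C._get_tracing_state() and not torch.compiler.is_compiling():
            return t.rename(None)
        return t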
10 changes: 5 additions & 5 deletions src/brevitas/nn/mixin/parameter.py
@@ -101,11 +101,11 @@ def __init__(
             proxy_prefix='bias_',
             **kwargs)
 
-    def quant_bias(self):
-        if self.bias is None:
-            return None
-        quant_bias = self.bias_quant(self.bias)
-        return quant_bias
+    # def quant_bias(self):
+    #     if self.bias is None:
+    #         return None
+    #     quant_bias = self.bias_quant(self.bias)
+    #     return quant_bias
 
     def register_parameter(self, name, value):
         super(QuantBiasMixin, self).register_parameter(name, value)
6 changes: 5 additions & 1 deletion src/brevitas/nn/quant_layer.py
@@ -9,8 +9,8 @@
 from torch import Tensor
 from torch.nn import Module
 
+from brevitas.quant_tensor import _unpack_quant_tensor
 from brevitas.quant_tensor import QuantTensor
-from brevitas.utils.torch_utils import compute_channel_view_shape
 
 from .mixin import *
 from .utils import merge_bn
@@ -153,6 +153,10 @@ def forward_impl(self, inp: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTensor]:
 
         if self.bias is not None:
             quant_bias = self.bias_quant(self.bias, quant_input, quant_weight)
+            # If bias has its own scale, we dequantize it, since it cannot be added to input @ weight anyway
+            # This simplifies control flow later for QuantTensors and torch.compile
+            if not self.bias_quant.requires_input_scale:
+                quant_bias = _unpack_quant_tensor(quant_bias)
         else:
             quant_bias = None
         output_tensor = self.inner_forward_impl(quant_input, quant_weight, quant_bias)
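The inserted branch applies the rule stated in the comment: a bias quantized with its own standalone scale cannot share the accumulator scale of input @ weight, so it is unpacked to a plain tensor up front, leaving a single code path for QuantTensors and torch.compile. A hedged sketch of that flow (the helper name resolve_bias is hypothetical):

    from brevitas.quant_tensor import _unpack_quant_tensor

    def resolve_bias(bias_quant, bias, quant_input, quant_weight):
        quant_bias = bias_quant(bias, quant_input, quant_weight)
        # A standalone bias scale cannot be folded into the output QuantTensor,
        # so dequantize immediately and pass inner_forward_impl a plain Tensor.
        if not bias_quant.requires_input_scale:
            return _unpack_quant_tensor(quant_bias)
        return quant_bias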
21 changes: 11 additions & 10 deletions src/brevitas/proxy/parameter_quant.py
@@ -99,17 +99,18 @@ def __init__(self, quant_layer: nn.Module, quant_injector: Injector) -> None:
         super().__init__(quant_layer, quant_injector)
         self._cached_bias = None
         self.cache_inference_quant_bias = False
+        self.requires_input_scale = self.quant_injector.requires_input_scale and self.is_quant_enabled
 
     @property
     def tracked_parameter_list(self):
         return [m.bias for m in self.tracked_module_list if m.bias is not None]
 
-    @property
-    def requires_input_scale(self) -> bool:
-        if self.is_quant_enabled:
-            return self.quant_injector.requires_input_scale
-        else:
-            return False
+    # @property
+    # def requires_input_scale(self) -> bool:
+    #     if self.is_quant_enabled:
+    #         return self.quant_injector.requires_input_scale
+    #     else:
+    #         return False
 
     def get_cached(self, attr):
         if self._cached_bias is None:
@@ -249,7 +250,7 @@ class BiasQuantProxyFromInjector(BiasQuantProxyFromInjectorBase):
     def scale(self):
         if not self.is_quant_enabled:
             return None
-        if self.requires_input_scale:
+        if self.requires_input_scale and self.is_quant_enabled:
             cache = self.get_cached('scale')
             return cache
         zhs = self._zero_hw_sentinel()
@@ -282,7 +283,7 @@ def compute_bias_scale(
             self,
             input: Optional[Union[Tensor, IntQuantTensor]],
             weight: Optional[Union[Tensor, IntQuantTensor]]) -> Optional[Tensor]:
-        if not self.requires_input_scale:
+        if not (self.requires_input_scale and self.is_quant_enabled):
             return None
         if not isinstance(input, IntQuantTensor) or not isinstance(weight, IntQuantTensor):
             return None
@@ -305,12 +306,12 @@ def forward(
         input_scale = self.compute_bias_scale(input, weight)
         if self.is_quant_enabled:
             impl = self.export_handler if self.export_mode else self.tensor_quant
-            if self.requires_input_scale and input_scale is None:
+            if self.requires_input_scale and self.is_quant_enabled and input_scale is None:
                 input_scale = self.scale()
                 if input_scale is None:
                     raise RuntimeError("Input scale required")
 
-            if self.requires_input_scale:
+            if self.requires_input_scale and self.is_quant_enabled:
                 input_scale = input_scale.view(-1)
                 out, out_scale, out_zp, out_bit_width = impl(x, input_scale)
             else:
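This is the recurring pattern of the commit: values that used to live behind @property and were recomputed on every access (requires_input_scale here, is_signed and is_zero_zero_point in the next file) are resolved once in __init__, so TorchDynamo sees a plain constant attribute instead of re-tracing injector lookups on each forward. A simplified sketch, not the actual Brevitas class:

    class BiasProxySketch:

        def __init__(self, quant_injector, is_quant_enabled: bool):
            # Resolved once at construction; Dynamo specializes on the
            # resulting Python bool rather than tracing a property per call.
            self.requires_input_scale = (
                quant_injector.requires_input_scale and is_quant_enabled)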
17 changes: 8 additions & 9 deletions src/brevitas/proxy/quant_proxy.py
@@ -13,6 +13,7 @@
 from brevitas import config
 from brevitas.common import ExportMixin
 from brevitas.core.utils import StatelessBuffer
+from brevitas.core.zero_point import ZeroZeroPoint
 from brevitas.inject import BaseInjector as Injector
 from brevitas.utils.quant_utils import float_to_int_impl_to_enum
 
@@ -29,15 +30,11 @@ def _is_groupwise(quant_injector):
 
 
 def _is_signed(quant_injector):
-    if 'signed' in quant_injector:
-        return quant_injector.signed
-    return None
+    return quant_injector.signed
 
 
 def _is_narrow_range(quant_injector):
-    if 'narrow_range' in quant_injector:
-        return quant_injector.narrow_range
-    return None
+    return quant_injector.narrow_range
 
 
 def _rounding_mode(quant_injector):
@@ -88,6 +85,8 @@ def __init__(self, quant_layer: nn.Module, quant_injector: Injector) -> None:
         self.tracked_module_list = []
         self.add_tracked_module(quant_layer)
         self.disable_quant = False
+        self.is_zero_zero_point = self.quant_injector.zero_point_impl == ZeroZeroPoint if 'zero_point_impl' in quant_injector else False
+        self.is_signed = self.quant_injector.signed if 'signed' in quant_injector else None
 
     @property
     def requires_export_handler(self):
@@ -108,9 +107,9 @@ def init_tensor_quant(self):
     def is_quant_enabled(self):
         return not self.disable_quant and self.tensor_quant is not None
 
-    @property
-    def is_signed(self):
-        return _is_signed(self.quant_injector)
+    # @property
+    # def is_signed(self):
+    #     return _is_signed(self.quant_injector)
 
     @property
     def is_groupwise(self):
19 changes: 13 additions & 6 deletions src/brevitas/proxy/runtime_quant.py
@@ -116,11 +116,14 @@ def retrieve_attribute(self, attribute, force_eval):
     def is_quant_enabled(self):
         return self._is_quant_enabled and not self.disable_quant
 
-    @property
-    def is_signed(self):
-        if self._cached_act is not None:
-            return self._cached_act.signed
-        return super().is_signed
+    # @property
+    # def is_signed(self):
+    #     if self._cached_act is not None:
+    #         return self._cached_act.signed
+    #     try:
+    #         return super().is_signed
+    #     except:
+    #         return None
 
     @is_quant_enabled.setter
     def is_quant_enabled(self, is_quant_enabled):
@@ -174,7 +177,11 @@ def forward(self, x: Union[Tensor, QuantTensor]) -> Union[Tensor, IntQuantTensor]:
             # If y is an empty IntQuantTensor, we need to check if this is a passthrough proxy,
             # otherwise return a simple Tensor
             if isinstance(y, tuple) and not any(map(lambda f: f is None, y)):
-                out = IntQuantTensor(*y, signed=self.is_signed, training=self.training)
+                out = IntQuantTensor(
+                    *y,
+                    signed=self.is_signed,
+                    training=self.training,
+                    _zero_zero_point=self.is_zero_zero_point)
             elif self.is_passthrough_act:  # preserve scale/zp/bit/sign even without output quant
                 if isinstance(y, tuple):
                     y = y[0]
25 changes: 19 additions & 6 deletions src/brevitas/quant_tensor/int_quant_tensor.py
@@ -2,6 +2,8 @@
 # SPDX-License-Identifier: BSD-3-Clause
 
 import torch
+from torch._dynamo import allow_in_graph
+from torch.utils._pytree import register_pytree_node
 
 from brevitas.function.ops import max_int
 from brevitas.function.ops import min_int
@@ -20,19 +22,20 @@
 
 class IntQuantTensor(IntQuantTensorBase, QuantTensor):
 
-    def __new__(cls, value, scale, zero_point, bit_width, signed, training):
+    def __new__(cls, value, scale, zero_point, bit_width, signed, training, _zero_zero_point=False):
 
-        if not isinstance(scale, torch.Tensor):
+        if not isinstance(scale, (torch.Tensor, torch.fx.proxy.Proxy)):
             scale = torch.tensor(scale, dtype=torch.float)
-        if not isinstance(zero_point, torch.Tensor):
+        if not isinstance(zero_point, (torch.Tensor, torch.fx.proxy.Proxy)):
             zero_point = torch.tensor(zero_point, dtype=torch.float)
-        if not isinstance(bit_width, torch.Tensor):
+        if not isinstance(bit_width, (torch.Tensor, torch.fx.proxy.Proxy)):
             bit_width = torch.tensor(bit_width, dtype=torch.float)
-        if not isinstance(signed, torch.Tensor):
+        if not isinstance(signed, (torch.Tensor, torch.fx.proxy.Proxy)):
             signed = torch.tensor(signed, dtype=torch.bool)
-        if not isinstance(training, torch.Tensor):
+        if not isinstance(training, (torch.Tensor, torch.fx.proxy.Proxy)):
             training = torch.tensor(training, dtype=torch.bool)
         quant_tensor = super().__new__(cls, value, scale, zero_point, bit_width, signed, training)
+        quant_tensor._zero_zero_point = _zero_zero_point
         return quant_tensor
 
     @property
@@ -362,3 +365,13 @@ def __abs__(self):
 
     def __pos__(self):
         return self
+
+
+# def flatten(int_qt):
+#     return list(int_qt._fields.keys(), int_qt._fields.values())
+# def unflatten(key, values):
+#     init_args = dict(zip(key, values))
+#     return IntQuantTensor(**init_args)
+
+# register_pytree_node(IntQuantTensor, flatten_fn=flatten, unflatten_fn=unflatten)
+allow_in_graph(IntQuantTensor)
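allow_in_graph tells TorchDynamo to place IntQuantTensor construction directly into the FX graph instead of tracing through __new__, which otherwise graph-breaks on the NamedTuple subclass. The commented-out pytree registration above would not run as written (list() takes a single argument); a working version might look like the sketch below, assuming the field order shown in __new__:

    from torch.utils._pytree import register_pytree_node

    def _flatten_int_qt(qt):
        # children are the tensor fields; context carries the Python-only flag
        children = [qt.value, qt.scale, qt.zero_point, qt.bit_width, qt.signed, qt.training]
        return children, qt._zero_zero_point

    def _unflatten_int_qt(children, context):
        return IntQuantTensor(*children, _zero_zero_point=context)

    register_pytree_node(IntQuantTensor, _flatten_int_qt, _unflatten_int_qt)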
20 changes: 11 additions & 9 deletions src/brevitas/quant_tensor/int_torch_handler.py
@@ -137,6 +137,15 @@ def adaptive_avg_pool2d_handler(quant_input, output_shape):
     return quant_input
 
 
+def check_zp(quant_input, quant_weight):
+    if not quant_input._zero_zero_point or not quant_weight._zero_zero_point:
+        warnings.warn("Computing zero point of output accumulator not supported yet.")
+        return False
+    else:
+        return True
+
+
+# from functorch.experimental.control_flow import cond
 def quant_layer(fn, quant_input, quant_weight, bias, *args, **kwargs):
     from brevitas.quant_tensor import _unpack_quant_tensor
     from brevitas.quant_tensor import IntQuantTensor
@@ -179,9 +188,7 @@ def quant_layer(fn, quant_input, quant_weight, bias, *args, **kwargs):
 
     if bias is not None:
         if output_scale is not None:
-            if (isinstance(bias, IntQuantTensor) and
-                    not torch.allclose(bias.scale, output_scale)) or not isinstance(bias,
-                                                                                    IntQuantTensor):
+            if not isinstance(bias, IntQuantTensor):
                 channel_dim = -1 if isinstance(fn, torch.nn.Linear) else 1
                 output_scale_broadcast_shape = compute_channel_view_shape(
                     quant_input, channel_dim=channel_dim)
@@ -193,12 +200,7 @@ def quant_layer(fn, quant_input, quant_weight, bias, *args, **kwargs):
             output_bit_width = output_bit_width + 1
 
     if compute_output_quant_tensor:
-        if (isinstance(quant_input, IntQuantTensor) and
-                (quant_input.zero_point != 0.0).any()) or (isinstance(quant_weight, IntQuantTensor) and
-                                                           (quant_weight.zero_point != 0.0).any()):
-            warnings.warn("Computing zero point of output accumulator not supported yet.")
-            compute_output_quant_tensor = False
-
+        compute_output_quant_tensor = check_zp(quant_input, quant_weight)
     if compute_output_quant_tensor:
         if output_zero_point is None:
             output_zero_point = torch.zeros(1).type_as(output)
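Moving the zero-point test onto the _zero_zero_point flag replaces a branch on tensor values, (zero_point != 0.0).any(), with a branch on a Python bool fixed at construction. Under torch.compile(fullgraph=True) the former is data-dependent control flow and raises instead of compiling, while the latter is specialized into the graph as a constant. A self-contained illustration of the difference, assuming PyTorch >= 2.0:

    import torch

    @torch.compile(fullgraph=True)
    def scale_shift(x: torch.Tensor, zero_zp: bool) -> torch.Tensor:
        # Branching on a Python bool is fine: Dynamo specializes the graph on it.
        if zero_zp:
            return x * 2
        return x * 2 - 1

    print(scale_shift(torch.ones(3), True))
    # By contrast, `if (x != 0).any(): ...` inside the compiled region is
    # data-dependent control flow and fails under fullgraph=True.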
@@ -476,7 +476,8 @@ def main():
             apply_bias_correction(calib_loader, quant_model)
         if args.compile and TORCH_GEQ_200:
             print("Applying torch.compile")
-            model = torch.compile(model)
+            quant_model = torch.compile(quant_model, fullgraph=True)
+
     # Validate the quant_model on the validation dataloader
     print("Starting validation:")
     validate(val_loader, quant_model, stable=dtype != torch.bfloat16)
3 changes: 2 additions & 1 deletion src/brevitas_examples/llm/llm_quant/eval.py
@@ -37,7 +37,8 @@ def model_eval(model, valenc, seqlen):
     dev = next(iter(model.parameters())).device
     with torch.no_grad():
         nlls = []
-        for inps in valenc:
+        for inps in tqdm(valenc):
+            inps = {k: v.to(dev) for k, v in inps.items()}
             lm_logits = model(**inps)['logits']
             shift_logits = lm_logits[:, :-1, :].contiguous()
             shift_labels = inps['input_ids'][:, 1:].to(dev)
10 changes: 9 additions & 1 deletion src/brevitas_examples/llm/main.py
@@ -373,12 +373,20 @@ def main():
 
     if args.compile:
         print("Applying compile")
-        model = torch.compile(model)
+        remove_hooks(model)
+        model = torch.compile(model, fullgraph=True)
+        offload_model(model)
+        # Pre-compile
+        with torch.no_grad():
+            model(**val_data[0])
 
+    import time
+    start = time.time()
     if args.eval:
         print("Model eval...")
         ppl = model_eval(model, val_data, args.seqlen)
         print(f"C4 perplexity: {ppl}")
+        print(time.time() - start)
     remove_hooks(model)
 
     if args.export_target:
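The LLM entry point now pays compilation cost before timing: hooks are removed so the graph stays whole, the model is compiled with fullgraph=True, and one batch is run under no_grad as a warm-up before the timed evaluation. A condensed sketch of that sequence (compile_and_warm and timed are hypothetical helpers, not part of the example scripts):

    import time

    import torch

    def compile_and_warm(model, sample_inputs: dict):
        model = torch.compile(model, fullgraph=True)
        with torch.no_grad():
            model(**sample_inputs)  # first call pays the compilation cost
        return model

    def timed(fn, *args):
        # Wraps the evaluation so the printed time reflects steady-state
        # inference rather than Dynamo tracing.
        start = time.time()
        out = fn(*args)
        print(f"elapsed: {time.time() - start:.1f}s")
        return out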