Fix (GPxQ): unwrap QuantTensor when dealing with QuantLinear #915

Closed
4 changes: 4 additions & 0 deletions src/brevitas/graph/gpfq.py
@@ -17,6 +17,7 @@
from brevitas.graph.gpxq import StopFwdException
from brevitas.graph.gpxq import SUPPORTED_CONV_OP
import brevitas.nn as qnn
from brevitas.quant_tensor import QuantTensor


class gpfq_mode(gpxq_mode):
@@ -168,6 +169,9 @@ def update_batch(self, module, input, current_layer):
if isinstance(self.layer, qnn.QuantLinear):
if len(inp.shape) > 2:
inp = inp.reshape((-1, sum(inp.shape[2:])))
# Unwrap tensor value if quantized input
if isinstance(inp, QuantTensor):
inp = inp.value
# For QuantLinear layer, groups will be 1
inp_processed = inp.unsqueeze(0)

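A standalone, hedged sketch (not part of the PR) of why `update_batch` can receive a QuantTensor in the first place: when the preceding layer is built with `return_quant_tensor=True`, the forward pre-hook registered on the next `QuantLinear` sees a QuantTensor rather than a plain `torch.Tensor`, and the two added lines above unwrap it to its dequantized `.value`.

```python
import torch
import brevitas.nn as qnn
from brevitas.quant import Int8ActPerTensorFloat
from brevitas.quant_tensor import QuantTensor

linear_0 = qnn.QuantLinear(
    3, 16, True, input_quant=Int8ActPerTensorFloat, return_quant_tensor=True)
linear_1 = qnn.QuantLinear(16, 10, True)

def inspect_input(module, args):
    inp = args[0]
    print(type(inp).__name__)          # QuantTensor (or a subclass) when fed by linear_0
    if isinstance(inp, QuantTensor):   # same unwrap as the hunk above
        inp = inp.value
    print(type(inp).__name__)          # Tensor

linear_1.register_forward_pre_hook(inspect_input)
linear_1(linear_0(torch.randn(4, 3)))
```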
4 changes: 4 additions & 0 deletions src/brevitas/graph/gptq.py
@@ -21,6 +21,7 @@
from brevitas.graph.gpxq import StopFwdException
from brevitas.graph.gpxq import SUPPORTED_CONV_OP
import brevitas.nn as qnn
from brevitas.quant_tensor import QuantTensor


class gptq_mode(gpxq_mode):
@@ -150,6 +151,9 @@ def update_batch(self, module, input, current_layer):
if isinstance(self.layer, qnn.QuantLinear):
if len(inp.shape) > 2:
inp = inp.reshape((-1, sum(inp.shape[2:])))
# Unwrap tensor value if QuantTensor
if isinstance(inp, QuantTensor):
inp = inp.value
inp = inp.t()
# For QuantLinear layer, groups will be 1
inp_processed = inp.unsqueeze(0)
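For context, a hedged sketch of how the patched `update_batch` hooks get exercised: a calibration loop under `gptq_mode` (the `gpfq_mode` loop is analogous), following the pattern used in the brevitas_examples PTQ flow. The loader and device handling here are assumptions, not taken from this PR.

```python
import torch
from brevitas.graph.gptq import gptq_mode

def calibrate_with_gptq(model, calib_loader):
    # Any QuantLinear whose upstream layer uses return_quant_tensor=True now has
    # its hooked input unwrapped before the update statistics are accumulated.
    model.eval()
    with torch.no_grad():
        with gptq_mode(model) as gptq:
            gptq_model = gptq.model
            for _ in range(gptq.num_layers):
                for images, _ in calib_loader:
                    gptq_model(images)
                gptq.update()
```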
22 changes: 21 additions & 1 deletion tests/brevitas/graph/equalization_fixtures.py
@@ -461,11 +461,31 @@ def forward(self, x):
return QuantConvTransposeModel


@pytest_cases.fixture()
def quant_linear_model():

class QuantLinearModel(nn.Module):

def __init__(self) -> None:
super().__init__()
self.linear_0 = qnn.QuantLinear(
3, 16, True, input_quant=Int8ActPerTensorFloat, return_quant_tensor=True)
self.linear_1 = qnn.QuantLinear(16, 10, True)

def forward(self, x):
x = self.linear_0(x)
x = self.linear_1(x)
return x

return QuantLinearModel


list_of_quant_fixtures = [
'quant_conv_with_input_quant_model',
'quant_convdepthconv_model',
'quant_residual_model',
'quant_convtranspose_model']
'quant_convtranspose_model',
'quant_linear_model']

toy_quant_model = fixture_union(
'toy_quant_model', list_of_quant_fixtures, ids=list_of_quant_fixtures)
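A hedged sketch, assuming the `QuantLinearModel` class defined in the fixture above has been brought into scope: one forward pass through `linear_0` yields the QuantTensor that `linear_1` (and therefore the GPxQ `update_batch` hook) receives, and whose `.value` field carries the dequantized tensor the hunks in gpfq.py/gptq.py use.

```python
import torch

model = QuantLinearModel()
out_0 = model.linear_0(torch.randn(4, 3))
print(type(out_0).__name__)   # QuantTensor
print(out_0.value.shape)      # torch.Size([4, 16]): plain torch.Tensor payload
```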
8 changes: 5 additions & 3 deletions tests/brevitas/graph/test_gpxq.py
@@ -111,7 +111,7 @@ def test_toymodels(

model_class = toy_quant_model
model = model_class()
if 'mha' in test_id:
if 'mha' in test_id or 'linear' in test_id:
inp = torch.randn(32, *IN_SIZE_LINEAR[1:])
else:
inp = torch.randn(32, *IN_SIZE_CONV_SMALL[1:])
@@ -129,12 +129,14 @@
act_order=act_order,
use_quant_activations=use_quant_activations)

elif (name == 'gpfq') and (acc_bit_width < 32) and (not use_quant_activations or
filter_func_str == 'identity'):
elif (name == 'gpfq') and (acc_bit_width < 32) and (
not use_quant_activations or filter_func_str
== 'identity') and not (hasattr(model, 'linear_0') and use_quant_activations):
# GPFA2Q requires that the quant activations are used. GPFA2Q.single_layer_update will
# raise a ValueError if GPFA2Q.quant_input is None (also see GPxQ.process_input). This will
# happen when `use_quant_activations=False` or when the input to a model is not quantized
# and `a2q_layer_filter_fnc` does not properly handle it.
# Note: quant_linear_model is not expected to raise this error since it has an input quant, hence the explicit check above to avoid entering this branch with that model
with pytest.raises(ValueError):
apply_gpxq(
calib_loader=calib_loader,
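A hedged sketch, not part of the test, re-evaluating the updated `elif` guard from `test_toymodels` with assumed parameter values to show why `quant_linear_model` skips the `pytest.raises(ValueError)` branch while the other toy models still enter it.

```python
name = 'gpfq'
acc_bit_width = 16            # assumed: any value < 32 selects the GPFA2Q path
use_quant_activations = True
filter_func_str = 'identity'

def expects_value_error(model_has_linear_0: bool) -> bool:
    # Mirrors the guard added in the diff above.
    return (name == 'gpfq' and acc_bit_width < 32 and
            (not use_quant_activations or filter_func_str == 'identity') and
            not (model_has_linear_0 and use_quant_activations))

print(expects_value_error(True))    # False: quant_linear_model has an input quant
print(expects_value_error(False))   # True: models with unquantized inputs still raise
```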