From b341bfa74e159a0e65db775b9e5c6931914a50cc Mon Sep 17 00:00:00 2001
From: Alessandro Pappalardo <1934033+volcacius@users.noreply.github.com>
Date: Thu, 6 Jul 2023 12:29:32 +0200
Subject: [PATCH] Fix (ptq): conflicts between gptq and equalization (#656)

---
 src/brevitas/graph/gptq.py         |  1 +
 src/brevitas/nn/equalized_layer.py | 14 +++++++++++---
 2 files changed, 12 insertions(+), 3 deletions(-)

diff --git a/src/brevitas/graph/gptq.py b/src/brevitas/graph/gptq.py
index bc5ce7991..43d1e7b92 100644
--- a/src/brevitas/graph/gptq.py
+++ b/src/brevitas/graph/gptq.py
@@ -227,6 +227,7 @@ def update_batch(self, module, input, current_layer):
         # Define batch size before re-organizing the input
         if hasattr(inp, 'names') and 'N' in inp.names:
             batch_dim = inp.names.index('N')
+            inp.rename_(None)
             inp = inp.transpose(0, batch_dim)
 
         batch_size = inp.shape[0]
diff --git a/src/brevitas/nn/equalized_layer.py b/src/brevitas/nn/equalized_layer.py
index d20e78537..35f636604 100644
--- a/src/brevitas/nn/equalized_layer.py
+++ b/src/brevitas/nn/equalized_layer.py
@@ -1,7 +1,11 @@
+from inspect import signature
+
 import torch
 
 from brevitas.nn.quant_mha import QuantMultiheadAttention
 
+INPUT_NAMES = ['input', 'inp', 'query', 'x']
+
 
 class EqualizedModule(torch.nn.Module):
 
@@ -11,9 +15,12 @@ def __init__(self, scale_module, layer) -> None:
         self.layer = layer
 
     def forward(self, *args, **kwargs):
-        kwargs.update(zip(self.layer.forward.__code__.co_varnames[1:], args))
+        # Convert args + kwargs + defaults into kwargs
+        bound_arguments = signature(self.layer.forward).bind(*args, **kwargs)
+        bound_arguments.apply_defaults()
+        kwargs = bound_arguments.arguments
 
-        possible_input_kwargs = ['input', 'inp', 'query']
+        possible_input_kwargs = INPUT_NAMES
         input_kwarg = [x for x in kwargs.keys() if x in possible_input_kwargs][0]
         x = kwargs[input_kwarg]
         out = x
@@ -31,5 +38,6 @@ def forward(self, *args, **kwargs):
         if isinstance(self.layer, (torch.nn.MultiheadAttention, QuantMultiheadAttention)):
             kwargs['key'] = out
             kwargs['value'] = out
-        out = self.layer(**kwargs)
+        # We convert everything to args so that hooks can work correctly
+        out = self.layer(*kwargs.values())
         return out
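
Below is a minimal, self-contained sketch (not part of the patch; it assumes only stock PyTorch and the standard-library inspect module, and uses a hypothetical toy_forward in place of self.layer.forward) of the two mechanisms the fix relies on: signature(...).bind() plus apply_defaults() collapses positional arguments, keyword arguments, and defaults into a single ordered mapping, and rename(None)/rename_(None) strips named-tensor dimension names before the input is re-organized in the GPTQ hook.

from inspect import signature

import torch


def toy_forward(input, weight=None, bias=None):
    # Hypothetical stand-in for self.layer.forward in EqualizedModule.
    return input


# 1) Normalize any call style into one kwargs mapping, as the patched
#    EqualizedModule.forward now does instead of reading co_varnames.
bound = signature(toy_forward).bind(torch.ones(2), weight=torch.ones(2))
bound.apply_defaults()
print(dict(bound.arguments))  # {'input': tensor(...), 'weight': tensor(...), 'bias': None}

# 2) Drop named-tensor dimension names before moving the batch dim to the
#    front, mirroring the inp.rename_(None) added to update_batch.
inp = torch.zeros(4, 3, names=('C', 'N'))
batch_dim = inp.names.index('N')
inp = inp.rename(None)  # out-of-place equivalent of inp.rename_(None)
inp = inp.transpose(0, batch_dim)
print(inp.shape)  # torch.Size([3, 4])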