diff --git a/src/brevitas/graph/gptq.py b/src/brevitas/graph/gptq.py index 55f787e26..5797d0083 100644 --- a/src/brevitas/graph/gptq.py +++ b/src/brevitas/graph/gptq.py @@ -381,16 +381,18 @@ def single_layer_update(self, percdamp=.01): if self.groups > 1: # In case of depthwise convs, each weight matrix interacts with only # part of the input values, thus with only one of the hessian matrix - for ii in range(self.groups): + for ii, perm in enumerate(permutation_list): weight_block[ii, i:] -= error[ii] * h_inv_block[ii, i, i:] + # We need to update the original weights + weight[ii, perm[i1:i2][i:]] = weight_block[ii, i:].to(dtype) else: + perm = permutation_list[0] weight_block[:, i:] -= error.unsqueeze(1).matmul( h_inv_block[0, i, i:].unsqueeze(0)) + # We need to update the original weights + weight[:, perm[i1:i2][i:]] = weight_block[:, i:].to(dtype) error_block[:, i] = error - # We need to update the original weights - weight[:, perm[i1:i2][i:]] = weight_block[:, i:].to(dtype) - if self.groups > 1: # In case of depthwise convs, each weight matrix interacts with only # part of the input values, thus with only one of the hessian matrix