From 0dc9268585cdf84e0d2cfdce6bb3c7d3f5d45c0d Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Fri, 28 Jul 2023 10:38:33 +0200 Subject: [PATCH] Fix (gptq): fix for depthwise act_order (#688) --- src/brevitas/graph/gptq.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/brevitas/graph/gptq.py b/src/brevitas/graph/gptq.py index a6c44f144..cd5129cc8 100644 --- a/src/brevitas/graph/gptq.py +++ b/src/brevitas/graph/gptq.py @@ -396,16 +396,18 @@ def single_layer_update(self, percdamp=.01): if self.groups > 1: # In case of depthwise convs, each weight matrix interacts with only # part of the input values, thus with only one of the hessian matrix - for ii in range(self.groups): + for ii, perm in enumerate(permutation_list): weight_block[ii, i:] -= error[ii] * h_inv_block[ii, i, i:] + # We need to update the original weights + weight[ii, perm[i1:i2][i:]] = weight_block[ii, i:].to(dtype) else: + perm = permutation_list[0] weight_block[:, i:] -= error.unsqueeze(1).matmul( h_inv_block[0, i, i:].unsqueeze(0)) + # We need to update the original weights + weight[:, perm[i1:i2][i:]] = weight_block[:, i:].to(dtype) error_block[:, i] = error - # We need to update the original weights - weight[:, perm[i1:i2][i:]] = weight_block[:, i:].to(dtype) - if self.groups > 1: # In case of depthwise convs, each weight matrix interacts with only # part of the input values, thus with only one of the hessian matrix