diff --git a/src/brevitas/graph/gptq.py b/src/brevitas/graph/gptq.py index ef8b3aee3..716f2fb04 100644 --- a/src/brevitas/graph/gptq.py +++ b/src/brevitas/graph/gptq.py @@ -373,7 +373,7 @@ def single_layer_update(self, percdamp=.01): perm = permutation_list[group_index] q = q_groups[group_index] # [OC/groups] w = weight[group_index, :, perm[i1:i2][i]].to(torch.float32) # [OC/groups] - d = h_inv[group_index, i, i] # [1] + d = h_inv_block[group_index, i, i] # [1] error = (w - q) / d # [OC/groups] error_block[group_index, :, i] = error # We need to update the original weights