diff --git a/src/brevitas/graph/gpxq.py b/src/brevitas/graph/gpxq.py index 75992e8fe..d8b436fc1 100644 --- a/src/brevitas/graph/gpxq.py +++ b/src/brevitas/graph/gpxq.py @@ -273,7 +273,7 @@ def get_quant_weights(self, i, i1, permutation_list): index = permutation_list[0][i] q = self.layer.quant_weight(quant_input=self.quant_metadata).value.unsqueeze( 0) # [1, OC, 1] - q = q[:, :, i:i + 1] # [groups, OC/groups, 1] + q = q[:, :, index:index + 1] # [groups, OC/groups, 1] else: index = permutation_list[0][i] subtensor_slice_list = [None, (index, index + 1)]