diff --git a/examples/transformers/prune_hf_bert.py b/examples/transformers/prune_hf_bert.py
index 70915279..1f5f0a9f 100644
--- a/examples/transformers/prune_hf_bert.py
+++ b/examples/transformers/prune_hf_bert.py
@@ -26,7 +26,7 @@
 pruner = tp.pruner.MetaPruner(
     model,
     example_inputs,
-    global_pruning=True, # If False, a uniform pruning ratio will be assigned to different layers.
+    global_pruning=False, # If False, a uniform pruning ratio will be assigned to different layers.
     importance=imp, # importance criterion for parameter selection
     iterative_steps=1, # the number of iterations to achieve target pruning ratio
     pruning_ratio=0.5,
diff --git a/torch_pruning/pruner/algorithms/metapruner.py b/torch_pruning/pruner/algorithms/metapruner.py
index ef8143e3..41a2e1e1 100644
--- a/torch_pruning/pruner/algorithms/metapruner.py
+++ b/torch_pruning/pruner/algorithms/metapruner.py
@@ -507,12 +507,14 @@ def prune_global(self) -> typing.Generator:
                     sub_pruning_idxs = sub_imp_argsort[:n_pruned_per_group]+chg*group_size
                     pruning_indices.append(sub_pruning_idxs)
             else:
-                pruning_indices = (imp <= thres).nonzero().view(-1)
+                _pruning_indices = (imp <= thres).nonzero().view(-1)
                 if len(pruning_indices)>0 and self.round_to:
-                    n_pruned = len(pruning_indices)
+                    n_pruned = len(_pruning_indices)
                     current_channels = get_channel_fn(module)
                     n_pruned = self._round_to(n_pruned, current_channels, self.round_to)
-                    pruning_indices = pruning_indices[:n_pruned]
+                    _pruning_indices = _pruning_indices[:n_pruned]
+                pruning_indices.append(_pruning_indices)
+
             # Prune heads
             if len(global_head_importance)>0:
                 if group in global_head_importance and n_heads_removed>0:
@@ -523,11 +525,12 @@
                         pruning_indices.append( torch.arange(head_id*group_size, (head_id+1)*group_size, device=head_imp.device) )
                     for qkv_layer in qkv_layers:
                         self.num_heads[qkv_layer] -= len(head_pruning_indices) # update num heads after pruning
 
-            pruning_indices = torch.cat(pruning_indices, 0)
-            pruning_indices = torch.unique(pruning_indices)
+            if len(pruning_indices)>0:
+                pruning_indices = torch.unique(torch.cat(pruning_indices, 0)).tolist()
+
             # create pruning group
             group = self.DG.get_pruning_group(
-                module, pruning_fn, pruning_indices.tolist())
+                module, pruning_fn, pruning_indices)
             if self.DG.check_pruning_group(group):
                 yield group