Skip to content

Commit

Permalink
Fixed a bug in global head pruning
Browse files Browse the repository at this point in the history
  • Loading branch information
VainF committed Oct 16, 2023
1 parent 50bf7d9 commit 74cf57f
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 7 deletions.
2 changes: 1 addition & 1 deletion examples/transformers/prune_hf_bert.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
pruner = tp.pruner.MetaPruner(
model,
example_inputs,
global_pruning=True, # If False, a uniform pruning ratio will be assigned to different layers.
global_pruning=False, # If False, a uniform pruning ratio will be assigned to different layers.
importance=imp, # importance criterion for parameter selection
iterative_steps=1, # the number of iterations to achieve target pruning ratio
pruning_ratio=0.5,
Expand Down
15 changes: 9 additions & 6 deletions torch_pruning/pruner/algorithms/metapruner.py
Original file line number Diff line number Diff line change
Expand Up @@ -507,12 +507,14 @@ def prune_global(self) -> typing.Generator:
sub_pruning_idxs = sub_imp_argsort[:n_pruned_per_group]+chg*group_size
pruning_indices.append(sub_pruning_idxs)
else:
pruning_indices = (imp <= thres).nonzero().view(-1)
_pruning_indices = (imp <= thres).nonzero().view(-1)
if len(pruning_indices)>0 and self.round_to:
n_pruned = len(pruning_indices)
n_pruned = len(_pruning_indices)
current_channels = get_channel_fn(module)
n_pruned = self._round_to(n_pruned, current_channels, self.round_to)
pruning_indices = pruning_indices[:n_pruned]
_pruning_indices = _pruning_indices[:n_pruned]
pruning_indices.append(_pruning_indices)

# Prune heads
if len(global_head_importance)>0:
if group in global_head_importance and n_heads_removed>0:
Expand All @@ -523,11 +525,12 @@ def prune_global(self) -> typing.Generator:
pruning_indices.append( torch.arange(head_id*group_size, (head_id+1)*group_size, device=head_imp.device) )
for qkv_layer in qkv_layers:
self.num_heads[qkv_layer] -= len(head_pruning_indices) # update num heads after pruning
pruning_indices = torch.cat(pruning_indices, 0)
pruning_indices = torch.unique(pruning_indices)

if len(pruning_indices)>0:
pruning_indices = torch.unique(torch.cat(pruning_indices, 0)).tolist()

# create pruning group
group = self.DG.get_pruning_group(
module, pruning_fn, pruning_indices.tolist())
module, pruning_fn, pruning_indices)
if self.DG.check_pruning_group(group):
yield group

0 comments on commit 74cf57f

Please sign in to comment.