[#410] Fixed a bug in transformer head pruning
VainF committed Aug 6, 2024
1 parent 70f2622 commit 594cd9f
Showing 1 changed file with 17 additions and 13 deletions.
30 changes: 17 additions & 13 deletions torch_pruning/pruner/algorithms/metapruner.py
@@ -238,18 +238,18 @@ def __init__(

         ###############################################
         # Count the number of total channels at initialization
-        if self.global_pruning:
-            initial_total_channels = 0
-            initial_total_heads = 0
-            for group in self.DG.get_all_groups(ignored_layers=self.ignored_layers, root_module_types=self.root_module_types):
-                group = self._downstream_node_as_root_if_attention(group)
-                initial_total_channels += ( (self.DG.get_out_channels(group[0][0].target.module) ) // self._get_channel_groups(group) )
-                for dep, _ in group:
-                    if dep.target.module in self.num_heads and self.DG.is_out_channel_pruning_fn(dep.handler):
-                        initial_total_heads += self.num_heads[dep.target.module]
-                        break # only count heads once
-            self.initial_total_channels = initial_total_channels
-            self.initial_total_heads = initial_total_heads
+        #if self.global_pruning:
+        initial_total_channels = 0
+        initial_total_heads = 0
+        for group in self.DG.get_all_groups(ignored_layers=self.ignored_layers, root_module_types=self.root_module_types):
+            group = self._downstream_node_as_root_if_attention(group)
+            initial_total_channels += ( (self.DG.get_out_channels(group[0][0].target.module) ) // self._get_channel_groups(group) )
+            for dep, _ in group:
+                if dep.target.module in self.num_heads and self.DG.is_out_channel_pruning_fn(dep.handler):
+                    initial_total_heads += self.num_heads[dep.target.module]
+                    break # only count heads once
+        self.initial_total_channels = initial_total_channels
+        self.initial_total_heads = initial_total_heads


     def step(self, interactive=False)-> typing.Union[typing.Generator, None]:
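The hunk above removes the `if self.global_pruning:` guard (left behind as a comment), so `initial_total_channels` and `initial_total_heads` are now counted even when local, per-layer pruning is used. A minimal usage sketch of the local head-pruning setup this affects, loosely following the Torch-Pruning timm-ViT example; the model choice and the constructor arguments (`num_heads`, `prune_num_heads`, `head_pruning_ratio`) are assumptions about the surrounding API, not part of this commit:

import timm
import torch
import torch_pruning as tp

model = timm.create_model("vit_base_patch16_224", pretrained=False)
example_inputs = torch.randn(1, 3, 224, 224)

# Each timm ViT attention block exposes a fused qkv Linear and a num_heads attribute;
# the pruner's num_heads dict is keyed by that qkv layer, mirroring the
# `dep.target.module in self.num_heads` check in the hunk above.
num_heads = {m.qkv: m.num_heads for m in model.modules()
             if hasattr(m, "qkv") and hasattr(m, "num_heads")}

pruner = tp.pruner.MetaPruner(
    model, example_inputs,
    importance=tp.importance.MagnitudeImportance(p=2),
    global_pruning=False,        # local ranking: the totals counted above are now available here too
    pruning_ratio=0.5,
    num_heads=num_heads,
    prune_num_heads=True,        # head pruning enabled; argument names assumed from the ViT example
    head_pruning_ratio=0.5,
)
pruner.step()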
@@ -565,7 +565,11 @@ def _prune(self) -> typing.Generator:
                 if len(ranking_scope[ATTN_HEAD_SCOPE])>0 and n_heads_removed>0:
                     if group in ranking_scope[ATTN_HEAD_SCOPE]:
                         qkv_layers, head_imp = ranking_scope[ATTN_HEAD_SCOPE][group]
-                        head_pruning_indices = (head_imp <= head_thres).nonzero().view(-1)
+                        if not self.global_pruning:
+                            n_heads_removed_per_group = int(self.get_target_head_pruning_ratio(qkv_layers[0]) * len(head_imp))
+                            head_pruning_indices = torch.topk(head_imp, k=n_heads_removed_per_group, largest=False)[1] # local ranking
+                        else:
+                            head_pruning_indices = (head_imp <= head_thres).nonzero().view(-1) # global ranking
                         if len(head_pruning_indices)>0:
                             for head_id in head_pruning_indices:
                                 pruning_indices.append( torch.arange(head_id*group_size, (head_id+1)*group_size, device=head_imp.device) )
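The new branch picks heads differently per mode: with local pruning, the k least-important heads inside each attention group are selected via `torch.topk(..., largest=False)`; with global pruning, every head whose importance falls at or below the shared threshold `head_thres` is selected. A small self-contained sketch of the two selection rules, with toy scores and a made-up group size (purely illustrative, not taken from the commit):

import torch

head_imp = torch.tensor([0.9, 0.1, 0.5, 0.3])   # toy importance score per attention head
group_size = 64                                  # channels owned by each head (illustrative)

# Local ranking: remove a fixed fraction of the heads within this group.
head_pruning_ratio = 0.5                         # assumed per-layer ratio
k = int(head_pruning_ratio * len(head_imp))
local_idx = torch.topk(head_imp, k=k, largest=False)[1]      # lowest-k heads -> tensor([1, 3])

# Global ranking: remove every head at or below a threshold shared across layers.
head_thres = 0.4                                 # assumed global threshold
global_idx = (head_imp <= head_thres).nonzero().view(-1)     # -> tensor([1, 3])

# Either way, each selected head expands into its contiguous block of channel indices.
pruning_indices = [torch.arange(h * group_size, (h + 1) * group_size) for h in local_idx]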
