From 594cd9fefd19ec1878b75b865a997c4fd6cba47b Mon Sep 17 00:00:00 2001 From: VainF <2218880241@qq.com> Date: Tue, 6 Aug 2024 14:48:20 +0800 Subject: [PATCH] [#410] Fixed a bug in transformer head pruning --- torch_pruning/pruner/algorithms/metapruner.py | 30 +++++++++++-------- 1 file changed, 17 insertions(+), 13 deletions(-) diff --git a/torch_pruning/pruner/algorithms/metapruner.py b/torch_pruning/pruner/algorithms/metapruner.py index 4abd778..28fa644 100644 --- a/torch_pruning/pruner/algorithms/metapruner.py +++ b/torch_pruning/pruner/algorithms/metapruner.py @@ -238,18 +238,18 @@ def __init__( ############################################### # Count the number of total channels at initialization - if self.global_pruning: - initial_total_channels = 0 - initial_total_heads = 0 - for group in self.DG.get_all_groups(ignored_layers=self.ignored_layers, root_module_types=self.root_module_types): - group = self._downstream_node_as_root_if_attention(group) - initial_total_channels += ( (self.DG.get_out_channels(group[0][0].target.module) ) // self._get_channel_groups(group) ) - for dep, _ in group: - if dep.target.module in self.num_heads and self.DG.is_out_channel_pruning_fn(dep.handler): - initial_total_heads += self.num_heads[dep.target.module] - break # only count heads once - self.initial_total_channels = initial_total_channels - self.initial_total_heads = initial_total_heads + #if self.global_pruning: + initial_total_channels = 0 + initial_total_heads = 0 + for group in self.DG.get_all_groups(ignored_layers=self.ignored_layers, root_module_types=self.root_module_types): + group = self._downstream_node_as_root_if_attention(group) + initial_total_channels += ( (self.DG.get_out_channels(group[0][0].target.module) ) // self._get_channel_groups(group) ) + for dep, _ in group: + if dep.target.module in self.num_heads and self.DG.is_out_channel_pruning_fn(dep.handler): + initial_total_heads += self.num_heads[dep.target.module] + break # only count heads once + self.initial_total_channels = initial_total_channels + self.initial_total_heads = initial_total_heads def step(self, interactive=False)-> typing.Union[typing.Generator, None]: @@ -565,7 +565,11 @@ def _prune(self) -> typing.Generator: if len(ranking_scope[ATTN_HEAD_SCOPE])>0 and n_heads_removed>0: if group in ranking_scope[ATTN_HEAD_SCOPE]: qkv_layers, head_imp = ranking_scope[ATTN_HEAD_SCOPE][group] - head_pruning_indices = (head_imp <= head_thres).nonzero().view(-1) + if not self.global_pruning: + n_heads_removed_per_group = int(self.get_target_head_pruning_ratio(qkv_layers[0]) * len(head_imp)) + head_pruning_indices = torch.topk(head_imp, k=n_heads_removed_per_group, largest=False)[1] # local ranking + else: + head_pruning_indices = (head_imp <= head_thres).nonzero().view(-1) # global ranking if len(head_pruning_indices)>0: for head_id in head_pruning_indices: pruning_indices.append( torch.arange(head_id*group_size, (head_id+1)*group_size, device=head_imp.device) )