diff --git a/examples/transformers/prune_timm_vit.py b/examples/transformers/prune_timm_vit.py index e8833ebc..b0002bc9 100644 --- a/examples/transformers/prune_timm_vit.py +++ b/examples/transformers/prune_timm_vit.py @@ -150,7 +150,7 @@ def main(): num_heads=num_heads, # number of heads in self attention prune_num_heads=args.prune_num_heads, # reduce num_heads by pruning entire heads (default: False) prune_head_dims=not args.prune_num_heads, # reduce head_dim by pruning featrues dims of each head (default: True) - head_pruning_ratio=args.head_pruning_ratio, # remove 50% heads, only works when prune_num_heads=True (default: 0.0) + head_pruning_ratio=0.5, #args.head_pruning_ratio, # remove 50% heads, only works when prune_num_heads=True (default: 0.0) round_to=2 ) diff --git a/torch_pruning/pruner/algorithms/metapruner.py b/torch_pruning/pruner/algorithms/metapruner.py index 41a2e1e1..2bf96254 100644 --- a/torch_pruning/pruner/algorithms/metapruner.py +++ b/torch_pruning/pruner/algorithms/metapruner.py @@ -84,7 +84,7 @@ def __init__( if ch_sparsity is not None: warnings.warn("ch_sparsity is deprecated in v1.3.0. Please use pruning_ratio.") - pruning_ratio = pruning_ratio_dict + pruning_ratio = pruning_ratio if ch_sparsity_dict is not None: warnings.warn("ch_sparsity_dict is deprecated in v1.3.0. Please use pruning_ratio_dict instead.") pruning_ratio_dict = ch_sparsity_dict @@ -324,7 +324,7 @@ def _round_to(self, n_pruned, current_channels, round_to): rounded_channels = current_channels - n_pruned rounded_channels = rounded_channels + (round_to - rounded_channels % round_to) n_pruned = current_channels - rounded_channels - return n_pruned + return max(n_pruned, 0) def prune_local(self) -> typing.Generator: if self.current_step > self.iterative_steps: @@ -332,18 +332,20 @@ def prune_local(self) -> typing.Generator: return for group in self.DG.get_all_groups(ignored_layers=self.ignored_layers, root_module_types=self.root_module_types): - if self._check_pruning_ratio(group): # check pruning ratio - + ################################## + # Compute raw importance score + ################################## group = self._downstream_node_as_root_if_attention(group) - module = group[0][0].target.module pruning_fn = group[0][0].handler ch_groups = self._get_channel_groups(group) - imp = self.estimate_importance(group, ch_groups=ch_groups) if imp is None: continue + ################################## + # Compute the number of dims/channels to prune + ################################## if self.DG.is_out_channel_pruning_fn(pruning_fn): current_channels = self.DG.get_out_channels(module) target_pruning_ratio = self.get_target_pruning_ratio(module) @@ -358,52 +360,47 @@ def prune_local(self) -> typing.Generator: self.layer_init_in_ch[module] * (1 - target_pruning_ratio) ) - # round to the nearest multiple of round_to if self.round_to: n_pruned = self._round_to(n_pruned, current_channels, self.round_to) - - if n_pruned <= 0: - continue - - if ch_groups > 1: # independent pruning for each group - group_size = current_channels // ch_groups - pruning_idxs = [] - - _is_attn, qkv_layers = self._is_attn_group(group) - n_heads_removed = 0 - head_pruning_idxs = [] - if _is_attn and self.prune_num_heads and self.get_target_head_pruning_ratio(qkv_layers[0])>0: # Prune entire attn heads - target_head_pruning_ratio = self.get_target_head_pruning_ratio(module) - n_heads_removed = self.num_heads[qkv_layers[0]] - int(self.init_num_heads[qkv_layers[0]] * (1 - target_head_pruning_ratio)) - head_imp = imp.view(ch_groups, -1).mean(1) - for head_id in torch.argsort(head_imp)[:n_heads_removed]: - head_pruning_idxs.append( torch.arange(head_id*group_size, (head_id+1)*group_size) ) - if len(head_pruning_idxs)>0: - head_pruning_idxs = torch.cat(head_pruning_idxs, 0).tolist() - # Prune attn dims or grouped dims - dim_pruning_idxs = [] + ################################## + # collect pruning idxs + ################################## + pruning_idxs = [] + _is_attn, qkv_layers = self._is_attn_group(group) + group_size = current_channels // ch_groups + # dims/channels + if n_pruned > 0: if (self.prune_head_dims and _is_attn) or (not _is_attn): n_pruned_per_group = n_pruned // ch_groups - if n_pruned_per_group == 0: continue # skip - for chg in range(ch_groups): - sub_group_imp = imp[chg*group_size: (chg+1)*group_size] - sub_imp_argsort = torch.argsort(sub_group_imp) - sub_pruning_idxs = sub_imp_argsort[:n_pruned_per_group] + chg*group_size # offset - dim_pruning_idxs.append(sub_pruning_idxs) - dim_pruning_idxs = torch.cat(dim_pruning_idxs, 0).tolist() - pruning_idxs = list(set(dim_pruning_idxs + head_pruning_idxs)) - pruning_idxs.sort() + if n_pruned_per_group>0: + for chg in range(ch_groups): + sub_group_imp = imp[chg*group_size: (chg+1)*group_size] + sub_imp_argsort = torch.argsort(sub_group_imp) + sub_pruning_idxs = sub_imp_argsort[:n_pruned_per_group] + chg*group_size # offset + pruning_idxs.append(sub_pruning_idxs) else: # no channel grouping imp_argsort = torch.argsort(imp) - pruning_idxs = imp_argsort[:n_pruned].tolist() + pruning_idxs.append( imp_argsort[:n_pruned] ) + + # num heads + if _is_attn and self.prune_num_heads: # Prune entire attn heads + target_head_pruning_ratio = self.get_target_head_pruning_ratio(qkv_layers[0]) + n_heads_removed = self.num_heads[qkv_layers[0]] - int(self.init_num_heads[qkv_layers[0]] * (1 - target_head_pruning_ratio)) + if n_heads_removed>0: + head_imp = imp.view(ch_groups, -1).mean(1) + for head_id in torch.argsort(head_imp)[:n_heads_removed]: + pruning_idxs.append( torch.arange(head_id*group_size, (head_id+1)*group_size, device=head_imp.device) ) + + if len(pruning_idxs)==0: continue + pruning_idxs = torch.unique( torch.cat(pruning_idxs, 0) ).tolist() group = self.DG.get_pruning_group( module, pruning_fn, pruning_idxs) if self.DG.check_pruning_group(group): # Update num heads after pruning - if ch_groups > 1 and n_heads_removed>0: + if _is_attn and self.prune_num_heads and n_heads_removed>0: for dep, _ in group: if dep.target.module in self.num_heads: self.num_heads[dep.target.module] -= n_heads_removed @@ -489,7 +486,7 @@ def prune_global(self) -> typing.Generator: # Prune feature dims/channels pruning_indices = [] - if len(global_importance)>0: + if len(global_importance)>0 and n_pruned>0: if ch_groups > 1: # re-compute importance for each channel group if channel grouping is enabled n_pruned_per_group = len((imp <= thres).nonzero().view(-1)) if n_pruned_per_group>0: @@ -516,8 +513,8 @@ def prune_global(self) -> typing.Generator: pruning_indices.append(_pruning_indices) # Prune heads - if len(global_head_importance)>0: - if group in global_head_importance and n_heads_removed>0: + if len(global_head_importance)>0 and n_heads_removed>0: + if group in global_head_importance: qkv_layers, head_imp = global_head_importance[group] head_pruning_indices = (head_imp <= head_thres).nonzero().view(-1) if len(head_pruning_indices)>0: @@ -526,9 +523,8 @@ def prune_global(self) -> typing.Generator: for qkv_layer in qkv_layers: self.num_heads[qkv_layer] -= len(head_pruning_indices) # update num heads after pruning - if len(pruning_indices)>0: - pruning_indices = torch.unique(torch.cat(pruning_indices, 0)).tolist() - + if len(pruning_indices)==0: continue + pruning_indices = torch.unique(torch.cat(pruning_indices, 0)).tolist() # create pruning group group = self.DG.get_pruning_group( module, pruning_fn, pruning_indices)