Fixed a bug in head pruning
VainF committed Oct 16, 2023
1 parent 74cf57f commit 5d90e6c
Showing 2 changed files with 42 additions and 46 deletions.
examples/transformers/prune_timm_vit.py (2 changes: 1 addition & 1 deletion)
@@ -150,7 +150,7 @@ def main():
num_heads=num_heads, # number of heads in self attention
prune_num_heads=args.prune_num_heads, # reduce num_heads by pruning entire heads (default: False)
prune_head_dims=not args.prune_num_heads, # reduce head_dim by pruning feature dims of each head (default: True)
head_pruning_ratio=args.head_pruning_ratio, # remove 50% heads, only works when prune_num_heads=True (default: 0.0)
head_pruning_ratio=0.5, #args.head_pruning_ratio, # remove 50% heads, only works when prune_num_heads=True (default: 0.0)
round_to=2
)
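
For context, the one-line change above feeds into the pruner constructed in prune_timm_vit.py. Below is a condensed, hypothetical sketch of that setup, not a copy of the script: the timm model name, the num_heads mapping, and the GroupNormImportance criterion are assumptions, while head_pruning_ratio=0.5 mirrors the value hard-coded by this commit.

import torch
import timm
import torch_pruning as tp

model = timm.create_model("vit_base_patch16_224", pretrained=False)
example_inputs = torch.randn(1, 3, 224, 224)

# Map each attention qkv layer to its head count (assumed layout of timm ViT blocks).
num_heads = {blk.attn.qkv: blk.attn.num_heads for blk in model.blocks}

pruner = tp.pruner.MetaPruner(
    model,
    example_inputs,
    importance=tp.importance.GroupNormImportance(p=2),  # assumed importance criterion
    pruning_ratio=0.5,
    num_heads=num_heads,       # number of heads in self attention
    prune_num_heads=True,      # remove entire heads rather than per-head feature dims
    prune_head_dims=False,
    head_pruning_ratio=0.5,    # remove 50% of the heads, as hard-coded above
    round_to=2,
)
pruner.step()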

torch_pruning/pruner/algorithms/metapruner.py (86 changes: 41 additions & 45 deletions)
@@ -84,7 +84,7 @@ def __init__(

if ch_sparsity is not None:
warnings.warn("ch_sparsity is deprecated in v1.3.0. Please use pruning_ratio.")
pruning_ratio = pruning_ratio_dict
pruning_ratio = pruning_ratio
if ch_sparsity_dict is not None:
warnings.warn("ch_sparsity_dict is deprecated in v1.3.0. Please use pruning_ratio_dict instead.")
pruning_ratio_dict = ch_sparsity_dict
@@ -324,26 +324,28 @@ def _round_to(self, n_pruned, current_channels, round_to):
rounded_channels = current_channels - n_pruned
rounded_channels = rounded_channels + (round_to - rounded_channels % round_to)
n_pruned = current_channels - rounded_channels
return n_pruned
return max(n_pruned, 0)
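
The max(n_pruned, 0) above guards against a corner case: when rounding the surviving channel count up to a multiple of round_to overshoots the current channel count, n_pruned would turn negative. A standalone sketch of the same arithmetic, with illustrative names rather than the library code:

def round_pruned_count(n_pruned: int, current_channels: int, round_to: int) -> int:
    # Keep a multiple of `round_to` channels after pruning, then recompute how many
    # to remove; clamp at zero so rounding can never yield a negative prune count.
    remaining = current_channels - n_pruned
    remaining = remaining + (round_to - remaining % round_to)
    return max(current_channels - remaining, 0)

# 10 channels, 1 to prune, round_to=8: 9 survivors round up to 16 > 10,
# so the clamped result is 0 (prune nothing) instead of -6.
assert round_pruned_count(1, 10, 8) == 0
assert round_pruned_count(5, 16, 4) == 4   # 11 survivors round up to 12, prune 4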

def prune_local(self) -> typing.Generator:
if self.current_step > self.iterative_steps:
warnings.warn("Pruning exceed the maximum iterative steps, no pruning will be performed.")
return

for group in self.DG.get_all_groups(ignored_layers=self.ignored_layers, root_module_types=self.root_module_types):

if self._check_pruning_ratio(group): # check pruning ratio

##################################
# Compute raw importance score
##################################
group = self._downstream_node_as_root_if_attention(group)

module = group[0][0].target.module
pruning_fn = group[0][0].handler
ch_groups = self._get_channel_groups(group)

imp = self.estimate_importance(group, ch_groups=ch_groups)
if imp is None: continue

##################################
# Compute the number of dims/channels to prune
##################################
if self.DG.is_out_channel_pruning_fn(pruning_fn):
current_channels = self.DG.get_out_channels(module)
target_pruning_ratio = self.get_target_pruning_ratio(module)
@@ -358,52 +360,47 @@ def prune_local(self) -> typing.Generator:
self.layer_init_in_ch[module] *
(1 - target_pruning_ratio)
)

# round to the nearest multiple of round_to
if self.round_to:
n_pruned = self._round_to(n_pruned, current_channels, self.round_to)

if n_pruned <= 0:
continue

if ch_groups > 1: # independent pruning for each group
group_size = current_channels // ch_groups
pruning_idxs = []

_is_attn, qkv_layers = self._is_attn_group(group)
n_heads_removed = 0
head_pruning_idxs = []
if _is_attn and self.prune_num_heads and self.get_target_head_pruning_ratio(qkv_layers[0])>0: # Prune entire attn heads
target_head_pruning_ratio = self.get_target_head_pruning_ratio(module)
n_heads_removed = self.num_heads[qkv_layers[0]] - int(self.init_num_heads[qkv_layers[0]] * (1 - target_head_pruning_ratio))
head_imp = imp.view(ch_groups, -1).mean(1)
for head_id in torch.argsort(head_imp)[:n_heads_removed]:
head_pruning_idxs.append( torch.arange(head_id*group_size, (head_id+1)*group_size) )
if len(head_pruning_idxs)>0:
head_pruning_idxs = torch.cat(head_pruning_idxs, 0).tolist()

# Prune attn dims or grouped dims
dim_pruning_idxs = []
##################################
# collect pruning idxs
##################################
pruning_idxs = []
_is_attn, qkv_layers = self._is_attn_group(group)
group_size = current_channels // ch_groups
# dims/channels
if n_pruned > 0:
if (self.prune_head_dims and _is_attn) or (not _is_attn):
n_pruned_per_group = n_pruned // ch_groups
if n_pruned_per_group == 0: continue # skip
for chg in range(ch_groups):
sub_group_imp = imp[chg*group_size: (chg+1)*group_size]
sub_imp_argsort = torch.argsort(sub_group_imp)
sub_pruning_idxs = sub_imp_argsort[:n_pruned_per_group] + chg*group_size # offset
dim_pruning_idxs.append(sub_pruning_idxs)
dim_pruning_idxs = torch.cat(dim_pruning_idxs, 0).tolist()
pruning_idxs = list(set(dim_pruning_idxs + head_pruning_idxs))
pruning_idxs.sort()
if n_pruned_per_group>0:
for chg in range(ch_groups):
sub_group_imp = imp[chg*group_size: (chg+1)*group_size]
sub_imp_argsort = torch.argsort(sub_group_imp)
sub_pruning_idxs = sub_imp_argsort[:n_pruned_per_group] + chg*group_size # offset
pruning_idxs.append(sub_pruning_idxs)
else: # no channel grouping
imp_argsort = torch.argsort(imp)
pruning_idxs = imp_argsort[:n_pruned].tolist()
pruning_idxs.append( imp_argsort[:n_pruned] )

# num heads
if _is_attn and self.prune_num_heads: # Prune entire attn heads
target_head_pruning_ratio = self.get_target_head_pruning_ratio(qkv_layers[0])
n_heads_removed = self.num_heads[qkv_layers[0]] - int(self.init_num_heads[qkv_layers[0]] * (1 - target_head_pruning_ratio))
if n_heads_removed>0:
head_imp = imp.view(ch_groups, -1).mean(1)
for head_id in torch.argsort(head_imp)[:n_heads_removed]:
pruning_idxs.append( torch.arange(head_id*group_size, (head_id+1)*group_size, device=head_imp.device) )

if len(pruning_idxs)==0: continue
pruning_idxs = torch.unique( torch.cat(pruning_idxs, 0) ).tolist()
group = self.DG.get_pruning_group(
module, pruning_fn, pruning_idxs)

if self.DG.check_pruning_group(group):
# Update num heads after pruning
if ch_groups > 1 and n_heads_removed>0:
if _is_attn and self.prune_num_heads and n_heads_removed>0:
for dep, _ in group:
if dep.target.module in self.num_heads:
self.num_heads[dep.target.module] -= n_heads_removed
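
The restructured prune_local above now gathers everything into a single pruning_idxs list: per-group feature dims first, then the full index ranges of any heads being removed, merged with torch.unique before building the pruning group. A self-contained sketch of that selection logic (a simplified stand-in with illustrative names, not the library code):

import torch

def select_pruning_idxs(imp, ch_groups, n_pruned, n_heads_removed, prune_head_dims=True):
    # imp holds one importance score per output channel; channels are laid out as
    # ch_groups contiguous blocks, one block per attention head.
    group_size = imp.numel() // ch_groups
    idxs = []

    # 1) Feature dims: drop the least important dims inside each head/group.
    if n_pruned > 0 and prune_head_dims:
        n_per_group = n_pruned // ch_groups
        if n_per_group > 0:
            for g in range(ch_groups):
                sub_imp = imp[g * group_size:(g + 1) * group_size]
                idxs.append(torch.argsort(sub_imp)[:n_per_group] + g * group_size)

    # 2) Whole heads: rank heads by mean importance and drop the weakest ones.
    if n_heads_removed > 0:
        head_imp = imp.view(ch_groups, -1).mean(1)
        for head_id in torch.argsort(head_imp)[:n_heads_removed]:
            idxs.append(torch.arange(head_id * group_size, (head_id + 1) * group_size))

    if len(idxs) == 0:
        return []  # nothing to prune for this group
    return torch.unique(torch.cat(idxs, 0)).tolist()

# 2 heads of 4 dims each: prune 1 dim per head and the weaker of the two heads.
scores = torch.tensor([0.9, 0.1, 0.8, 0.7, 0.2, 0.3, 0.4, 0.5])
print(select_pruning_idxs(scores, ch_groups=2, n_pruned=2, n_heads_removed=1))
# -> [1, 4, 5, 6, 7]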
@@ -489,7 +486,7 @@ def prune_global(self) -> typing.Generator:

# Prune feature dims/channels
pruning_indices = []
if len(global_importance)>0:
if len(global_importance)>0 and n_pruned>0:
if ch_groups > 1: # re-compute importance for each channel group if channel grouping is enabled
n_pruned_per_group = len((imp <= thres).nonzero().view(-1))
if n_pruned_per_group>0:
@@ -516,8 +513,8 @@ def prune_global(self) -> typing.Generator:
pruning_indices.append(_pruning_indices)

# Prune heads
if len(global_head_importance)>0:
if group in global_head_importance and n_heads_removed>0:
if len(global_head_importance)>0 and n_heads_removed>0:
if group in global_head_importance:
qkv_layers, head_imp = global_head_importance[group]
head_pruning_indices = (head_imp <= head_thres).nonzero().view(-1)
if len(head_pruning_indices)>0:
@@ -526,9 +523,8 @@ def prune_global(self) -> typing.Generator:
for qkv_layer in qkv_layers:
self.num_heads[qkv_layer] -= len(head_pruning_indices) # update num heads after pruning

if len(pruning_indices)>0:
pruning_indices = torch.unique(torch.cat(pruning_indices, 0)).tolist()

if len(pruning_indices)==0: continue
pruning_indices = torch.unique(torch.cat(pruning_indices, 0)).tolist()
# create pruning group
group = self.DG.get_pruning_group(
module, pruning_fn, pruning_indices)
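
The prune_global changes follow the same pattern: channel indices at or below a global importance threshold and the channels of any head below the head threshold are collected, deduplicated with torch.unique, and the group is skipped entirely (continue) when nothing qualifies, instead of building an empty pruning group. A minimal illustration of that threshold-based selection, with assumed names rather than the library API:

import torch

def global_select(imp, thres, head_imp=None, head_thres=None, group_size=1):
    # Channels whose importance falls at or below the global threshold ...
    indices = [(imp <= thres).nonzero().view(-1)]
    # ... plus every channel of any head whose mean importance falls below the head threshold.
    if head_imp is not None and head_thres is not None:
        for head_id in (head_imp <= head_thres).nonzero().view(-1):
            indices.append(torch.arange(head_id * group_size, (head_id + 1) * group_size))
    merged = torch.unique(torch.cat(indices, 0)).tolist()
    return merged or None  # None signals "skip this group", mirroring the added `continue`

# Only channel 2 falls below the channel threshold; no head importance is given.
print(global_select(torch.tensor([0.9, 0.8, 0.1, 0.7]), thres=0.2))   # -> [2]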
