Add more Docstrings
VainF committed Feb 21, 2024
1 parent b4f29f6 commit 7ec9176
Showing 8 changed files with 290 additions and 19 deletions.
37 changes: 30 additions & 7 deletions torch_pruning/dependency.py
@@ -9,7 +9,7 @@
__all__ = ["Dependency", "Group", "DependencyGraph"]

INDEX_MAPPING_PLACEHOLDER = None
MAX_RECURSION_DEPTH = 100
MAX_RECURSION_DEPTH = 500

class Node(object):
""" Node of DepGraph
@@ -421,12 +421,19 @@ def get_pruning_group(
pruning_fn: typing.Callable,
idxs: typing.Sequence[int],
) -> Group:
"""Get the pruning group of pruning_fn.
Args:
module (nn.Module): the to-be-pruned module/layer.
pruning_fn (Callable): the pruning function.
idxs (list or tuple): the indices of channels/dimensions.
grouped_idxs (bool): whether the indices are grouped. If True, idxs is a list of list, e.g., [[0,1,2], [3,4,5]], where each sublist is a group.
"""
Get the pruning group for a given module.
Args:
module (nn.Module): The module to be pruned.
pruning_fn (Callable): The pruning function.
idxs (list or tuple): The indices of channels/dimensions.
Returns:
Group: The pruning group containing the dependencies and indices.
Raises:
ValueError: If the module is not in the dependency graph.
"""
if module not in self.module2node:
raise ValueError(
Expand Down Expand Up @@ -500,6 +507,22 @@ def _fix_dependency_graph_non_recursive(dep, idxs, *args):
return merged_group

def get_all_groups(self, ignored_layers=[], root_module_types=(ops.TORCH_CONV, ops.TORCH_LINEAR)):
"""
Get all pruning groups of the traced model. Groups are generated from root modules whose types are specified in root_module_types.
Args:
ignored_layers (list): List of layers to be ignored during pruning.
root_module_types (tuple): Tuple of root module types to consider for pruning.
Yields:
list: A pruning group containing dependencies and their corresponding pruning handlers.
Example:
```python
for group in DG.get_all_groups(ignored_layers=[layer1, layer2], root_module_types=[nn.Conv2d]):
print(group)
```
"""
visited_layers = []
ignored_layers = ignored_layers+self.IGNORED_LAYERS_IN_TRACING

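For reference, a minimal sketch of how the two methods documented above are typically used, following the repository README. The `tp` alias, `tp.DependencyGraph`, `build_dependency`, `tp.prune_conv_out_channels`, `DG.check_pruning_group`, and the torchvision ResNet-18 model are assumptions about the public API, not part of this diff.

```python
import torch
import torchvision
import torch_pruning as tp

model = torchvision.models.resnet18()
example_inputs = torch.randn(1, 3, 224, 224)

# Trace the model once to build the dependency graph.
DG = tp.DependencyGraph().build_dependency(model, example_inputs=example_inputs)

# Group all layers that must be pruned together when output channels 2, 6, 9 of conv1 are removed.
group = DG.get_pruning_group(model.conv1, tp.prune_conv_out_channels, idxs=[2, 6, 9])
if DG.check_pruning_group(group):  # skip groups that would remove every channel
    group.prune()

# Enumerate every group rooted at a Conv2d/Linear layer, e.g. for a global importance ranking.
for g in DG.get_all_groups(ignored_layers=[model.fc], root_module_types=[torch.nn.Conv2d, torch.nn.Linear]):
    print(g)
```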
36 changes: 36 additions & 0 deletions torch_pruning/pruner/algorithms/batchnorm_scale_pruner.py
@@ -10,6 +10,42 @@
class BNScalePruner(MetaPruner):
"""Learning Efficient Convolutional Networks through Network Slimming,
https://arxiv.org/abs/1708.06519
Args:
# Basic
* model (nn.Module): A to-be-pruned model
* example_inputs (torch.Tensor or List): dummy inputs for graph tracing.
* importance (Callable): importance estimator.
* reg (float): regularization coefficient. Default: 1e-5.
* group_lasso (bool): use group lasso. Default: False.
* global_pruning (bool): enable global pruning. Default: False.
* pruning_ratio (float): global channel sparsity (also known as pruning ratio). Default: 0.5.
* pruning_ratio_dict (Dict[nn.Module, float]): layer-specific pruning ratios. Overrides pruning_ratio for the listed layers if specified. Default: None.
* max_pruning_ratio (float): the maximum pruning ratio. Default: 1.0.
* iterative_steps (int): number of steps for iterative pruning. Default: 1.
* iterative_pruning_ratio_scheduler (Callable): scheduler for iterative pruning. Default: linear_scheduler.
* ignored_layers (List[nn.Module | typing.Type]): ignored modules. Default: None.
* round_to (int): round channels to the nearest multiple of round_to. E.g., round_to=8 means channels will be rounded to 8x. Default: None.
# Advanced
* in_channel_groups (Dict[nn.Module, int]): The number of channel groups for layer input. Default: dict().
* out_channel_groups (Dict[nn.Module, int]): The number of channel groups for layer output. Default: dict().
* num_heads (Dict[nn.Module, int]): The number of heads for multi-head attention. Default: dict().
* prune_num_heads (bool): remove entire heads in multi-head attention. Default: False.
* prune_head_dims (bool): remove head dimensions in multi-head attention. Default: True.
* head_pruning_ratio (float): head pruning ratio. Default: 0.0.
* head_pruning_ratio_dict (Dict[nn.Module, float]): layer-specific head pruning ratio. Default: None.
* customized_pruners (dict): a dict containing module-pruner pairs. Default: None.
* unwrapped_parameters (dict): a dict containing unwrapped parameters & pruning dims. Default: None.
* root_module_types (list): types of prunable modules. Default: [nn.Conv2d, nn.Linear, nn.LSTM].
* forward_fn (Callable): A function to execute model.forward. Default: None.
* output_transform (Callable): A function to transform network outputs. Default: None.
# Deprecated
* channel_groups (Dict[nn.Module, int]): output channel grouping. Default: dict().
* ch_sparsity (float): the same as pruning_ratio. Default: None.
* ch_sparsity_dict (Dict[nn.Module, float]): the same as pruning_ratio_dict. Default: None.
"""

def __init__(
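A hedged sketch of how this pruner is typically driven during sparse training, based on the examples shipped with the repository; `tp.importance.BNScaleImportance`, the `regularize` method, and the placeholder loss are assumptions, not part of this commit.

```python
import torch
import torchvision
import torch_pruning as tp

model = torchvision.models.resnet18()
example_inputs = torch.randn(1, 3, 224, 224)

pruner = tp.pruner.BNScalePruner(
    model,
    example_inputs,
    importance=tp.importance.BNScaleImportance(),  # rank channels by |BN scaling factor|
    reg=1e-5,                                      # strength of the L1 penalty on BN scales
    pruning_ratio=0.5,
    ignored_layers=[model.fc],                     # keep the classifier untouched
)

# One illustrative sparse-training step (real data, loss and optimizer omitted).
out = model(example_inputs)
out.sum().backward()        # any loss works; this just populates the gradients
pruner.regularize(model)    # add the sparsity penalty to the BN-scale gradients
# optimizer.step() would follow here during real training

pruner.step()               # remove the channels whose BN scales were driven toward zero
```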
37 changes: 37 additions & 0 deletions torch_pruning/pruner/algorithms/group_norm_pruner.py
@@ -10,6 +10,43 @@
class GroupNormPruner(MetaPruner):
"""DepGraph: Towards Any Structural Pruning.
https://openaccess.thecvf.com/content/CVPR2023/html/Fang_DepGraph_Towards_Any_Structural_Pruning_CVPR_2023_paper.html
Args:
# Basic
* model (nn.Module): A to-be-pruned model
* example_inputs (torch.Tensor or List): dummy inputs for graph tracing.
* importance (Callable): importance estimator.
* reg (float): regularization coefficient. Default: 1e-5.
* alpha (float): regularization scaling factor, [2^0, 2^alpha]. Default: 4.
* global_pruning (bool): enable global pruning. Default: False.
* pruning_ratio (float): global channel sparsity (also known as pruning ratio). Default: 0.5.
* pruning_ratio_dict (Dict[nn.Module, float]): layer-specific pruning ratios. Overrides pruning_ratio for the listed layers if specified. Default: None.
* max_pruning_ratio (float): the maximum pruning ratio. Default: 1.0.
* iterative_steps (int): number of steps for iterative pruning. Default: 1.
* iterative_pruning_ratio_scheduler (Callable): scheduler for iterative pruning. Default: linear_scheduler.
* ignored_layers (List[nn.Module | typing.Type]): ignored modules. Default: None.
* round_to (int): round channels to the nearest multiple of round_to. E.g., round_to=8 means channels will be rounded to 8x. Default: None.
# Advanced
* in_channel_groups (Dict[nn.Module, int]): The number of channel groups for layer input. Default: dict().
* out_channel_groups (Dict[nn.Module, int]): The number of channel groups for layer output. Default: dict().
* num_heads (Dict[nn.Module, int]): The number of heads for multi-head attention. Default: dict().
* prune_num_heads (bool): remove entire heads in multi-head attention. Default: False.
* prune_head_dims (bool): remove head dimensions in multi-head attention. Default: True.
* head_pruning_ratio (float): head pruning ratio. Default: 0.0.
* head_pruning_ratio_dict (Dict[nn.Module, float]): layer-specific head pruning ratio. Default: None.
* customized_pruners (dict): a dict containing module-pruner pairs. Default: None.
* unwrapped_parameters (dict): a dict containing unwrapped parameters & pruning dims. Default: None.
* root_module_types (list): types of prunable modules. Default: [nn.Conv2d, nn.Linear, nn.LSTM].
* forward_fn (Callable): A function to execute model.forward. Default: None.
* output_transform (Callable): A function to transform network outputs. Default: None.
# Deprecated
* channel_groups (Dict[nn.Module, int]): output channel grouping. Default: dict().
* ch_sparsity (float): the same as pruning_ratio. Default: None.
* ch_sparsity_dict (Dict[nn.Module, float]): the same as pruning_ratio_dict. Default: None.
"""
def __init__(
self,
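A minimal sketch of the DepGraph-style pruner described above, assuming the grouped importance class `tp.importance.GroupNormImportance` and the `regularize` method from the repository examples; the ResNet-18 model and hyperparameter values are illustrative.

```python
import torch
import torchvision
import torch_pruning as tp

model = torchvision.models.resnet18()
example_inputs = torch.randn(1, 3, 224, 224)

pruner = tp.pruner.GroupNormPruner(
    model,
    example_inputs,
    importance=tp.importance.GroupNormImportance(p=2),  # group-level L2 importance
    reg=1e-5,
    alpha=4,
    pruning_ratio=0.5,
    ignored_layers=[model.fc],
)

# Optional sparse training with the grouped regularizer before pruning.
out = model(example_inputs)
out.sum().backward()          # placeholder loss to populate gradients
pruner.regularize(model)      # apply the group-lasso-style penalty to the gradients

pruner.step()                 # prune the least important coupled channel groups
```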
36 changes: 36 additions & 0 deletions torch_pruning/pruner/algorithms/growing_reg_pruner.py
@@ -9,6 +9,42 @@
class GrowingRegPruner(MetaPruner):
""" pruning with growing regularization
https://arxiv.org/abs/2012.09243
Args:
# Basic
* model (nn.Module): A to-be-pruned model
* example_inputs (torch.Tensor or List): dummy inputs for graph tracing.
* importance (Callable): importance estimator.
* reg (float): regularization coefficient. Default: 1e-5.
* delta_reg (float): increment of regularization coefficient. Default: 1e-5.
* global_pruning (bool): enable global pruning. Default: False.
* pruning_ratio (float): global channel sparsity (also known as pruning ratio). Default: 0.5.
* pruning_ratio_dict (Dict[nn.Module, float]): layer-specific pruning ratios. Overrides pruning_ratio for the listed layers if specified. Default: None.
* max_pruning_ratio (float): the maximum pruning ratio. Default: 1.0.
* iterative_steps (int): number of steps for iterative pruning. Default: 1.
* iterative_pruning_ratio_scheduler (Callable): scheduler for iterative pruning. Default: linear_scheduler.
* ignored_layers (List[nn.Module | typing.Type]): ignored modules. Default: None.
* round_to (int): round channels to the nearest multiple of round_to. E.g., round_to=8 means channels will be rounded to 8x. Default: None.
# Advanced
* in_channel_groups (Dict[nn.Module, int]): The number of channel groups for layer input. Default: dict().
* out_channel_groups (Dict[nn.Module, int]): The number of channel groups for layer output. Default: dict().
* num_heads (Dict[nn.Module, int]): The number of heads for multi-head attention. Default: dict().
* prune_num_heads (bool): remove entire heads in multi-head attention. Default: False.
* prune_head_dims (bool): remove head dimensions in multi-head attention. Default: True.
* head_pruning_ratio (float): head pruning ratio. Default: 0.0.
* head_pruning_ratio_dict (Dict[nn.Module, float]): layer-specific head pruning ratio. Default: None.
* customized_pruners (dict): a dict containing module-pruner pairs. Default: None.
* unwrapped_parameters (dict): a dict containing unwrapped parameters & pruning dims. Default: None.
* root_module_types (list): types of prunable modules. Default: [nn.Conv2d, nn.Linear, nn.LSTM].
* forward_fn (Callable): A function to execute model.forward. Default: None.
* output_transform (Callable): A function to transform network outputs. Default: None.
# Deprecated
* channel_groups (Dict[nn.Module, int]): output channel grouping. Default: dict().
* ch_sparsity (float): the same as pruning_ratio. Default: None.
* ch_sparsity_dict (Dict[nn.Module, float]): the same as pruning_ratio_dict. Default: None.
"""
def __init__(
self,
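A rough usage sketch under stated assumptions: the importance class, the `regularize` call, and especially the `update_reg` helper for growing the coefficient are taken from the repository's benchmark scripts and may differ in the released version; the coefficients shown are illustrative, not defaults.

```python
import torch
import torchvision
import torch_pruning as tp

model = torchvision.models.resnet18()
example_inputs = torch.randn(1, 3, 224, 224)

pruner = tp.pruner.GrowingRegPruner(
    model,
    example_inputs,
    importance=tp.importance.GroupNormImportance(p=2),
    reg=1e-4,         # initial regularization coefficient
    delta_reg=1e-5,   # increment applied each time the coefficient is grown
    pruning_ratio=0.5,
    ignored_layers=[model.fc],
)

# One illustrative sparse-training step; in practice this runs for many epochs
# and the coefficient is grown periodically so the penalty increases over time.
out = model(example_inputs)
out.sum().backward()
pruner.regularize(model)
pruner.update_reg()   # assumed helper that increases the coefficient by delta_reg

pruner.step()
```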
34 changes: 34 additions & 0 deletions torch_pruning/pruner/algorithms/magnitude_based_pruner.py
@@ -2,6 +2,40 @@

class MagnitudePruner(MetaPruner):
""" Prune the smallest magnitude weights
Args:
# Basic
* model (nn.Module): A to-be-pruned model
* example_inputs (torch.Tensor or List): dummy inputs for graph tracing.
* importance (Callable): importance estimator.
* global_pruning (bool): enable global pruning. Default: False.
* pruning_ratio (float): global channel sparsity (also known as pruning ratio). Default: 0.5.
* pruning_ratio_dict (Dict[nn.Module, float]): layer-specific pruning ratios. Overrides pruning_ratio for the listed layers if specified. Default: None.
* max_pruning_ratio (float): the maximum pruning ratio. Default: 1.0.
* iterative_steps (int): number of steps for iterative pruning. Default: 1.
* iterative_pruning_ratio_scheduler (Callable): scheduler for iterative pruning. Default: linear_scheduler.
* ignored_layers (List[nn.Module | typing.Type]): ignored modules. Default: None.
* round_to (int): round channels to the nearest multiple of round_to. E.g., round_to=8 means channels will be rounded to 8x. Default: None.
# Advanced
* in_channel_groups (Dict[nn.Module, int]): The number of channel groups for layer input. Default: dict().
* out_channel_groups (Dict[nn.Module, int]): The number of channel groups for layer output. Default: dict().
* num_heads (Dict[nn.Module, int]): The number of heads for multi-head attention. Default: dict().
* prune_num_heads (bool): remove entire heads in multi-head attention. Default: False.
* prune_head_dims (bool): remove head dimensions in multi-head attention. Default: True.
* head_pruning_ratio (float): head pruning ratio. Default: 0.0.
* head_pruning_ratio_dict (Dict[nn.Module, float]): layer-specific head pruning ratio. Default: None.
* customized_pruners (dict): a dict containing module-pruner pairs. Default: None.
* unwrapped_parameters (dict): a dict containing unwrapped parameters & pruning dims. Default: None.
* root_module_types (list): types of prunable modules. Default: [nn.Conv2d, nn.Linear, nn.LSTM].
* forward_fn (Callable): A function to execute model.forward. Default: None.
* output_transform (Callable): A function to transform network outputs. Default: None.
# Deprecated
* channel_groups (Dict[nn.Module, int]): output channel grouping. Default: dict().
* ch_sparsity (float): the same as pruning_ratio. Default: None.
* ch_sparsity_dict (Dict[nn.Module, float]): the same as pruning_ratio_dict. Default: None.
"""
pass

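A minimal quickstart for this pruner, following the repository README; the ResNet-18 model and `tp.importance.MagnitudeImportance` are assumptions drawn from the public API rather than part of this diff.

```python
import torch
import torchvision
import torch_pruning as tp

model = torchvision.models.resnet18()
example_inputs = torch.randn(1, 3, 224, 224)

pruner = tp.pruner.MagnitudePruner(
    model,
    example_inputs,
    importance=tp.importance.MagnitudeImportance(p=2),  # L2 norm of the grouped weights
    pruning_ratio=0.5,                                   # remove 50% of the channels
    ignored_layers=[model.fc],                           # never prune the final classifier
)

pruner.step()                          # one-shot pruning (iterative_steps=1 by default)
print(model(example_inputs).shape)     # the pruned model still runs end-to-end
```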
16 changes: 16 additions & 0 deletions torch_pruning/pruner/algorithms/metapruner.py
@@ -33,6 +33,7 @@ class MetaPruner:
* prune_num_heads (bool): remove entire heads in multi-head attention. Default: False.
* prune_head_dims (bool): remove head dimensions in multi-head attention. Default: True.
* head_pruning_ratio (float): head pruning ratio. Default: 0.0.
* head_pruning_ratio_dict (Dict[nn.Module, float]): layer-specific head pruning ratio. Default: None.
* customized_pruners (dict): a dict containing module-pruner pairs. Default: None.
* unwrapped_parameters (dict): a dict containing unwrapped parameters & pruning dims. Default: None.
* root_module_types (list): types of prunable modules. Default: [nn.Conv2d, nn.Linear, nn.LSTM].
@@ -227,6 +228,21 @@ def step(self, interactive=False)-> typing.Union[typing.Generator, None]:
for group in pruning_method():
group.prune()

def manual_prune(self, layer, pruning_fn, pruning_ratios_or_idxs):
    # Accept either a float pruning ratio or an explicit list of channel indices.
    if isinstance(pruning_ratios_or_idxs, float):
        if self.DG.is_out_channel_pruning_fn(pruning_fn):
            prunable_channels = self.DG.get_out_channels(layer)
        else:
            prunable_channels = self.DG.get_in_channels(layer)
        full_group = self.DG.get_pruning_group(layer, pruning_fn, list(range(prunable_channels)))
        imp = self.estimate_importance(full_group)
        imp_argsort = torch.argsort(imp)
        n_pruned = int(prunable_channels * (1 - pruning_ratios_or_idxs))
        pruning_idxs = imp_argsort[:n_pruned]  # channels with the lowest importance
    else:
        pruning_idxs = pruning_ratios_or_idxs  # explicit indices were given

    group = self.DG.get_pruning_group(layer, pruning_fn, pruning_idxs)
    group.prune()

def estimate_importance(self, group) -> torch.Tensor:
return self.importance(group)

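A small sketch of how the newly added `manual_prune` method might be called, assuming a `MagnitudePruner` built as in the README; the chosen layers and `tp.prune_conv_out_channels` are illustrative. Note that the float branch above derives the number of removed channels from `(1 - ratio)`, so the exact convention is worth verifying against the released version.

```python
import torch
import torchvision
import torch_pruning as tp

model = torchvision.models.resnet18()
example_inputs = torch.randn(1, 3, 224, 224)

pruner = tp.pruner.MagnitudePruner(
    model,
    example_inputs,
    importance=tp.importance.MagnitudeImportance(p=2),
    pruning_ratio=0.5,
    ignored_layers=[model.fc],
)

# Prune one layer by explicit output-channel indices...
pruner.manual_prune(model.conv1, tp.prune_conv_out_channels, [0, 1, 2, 3])

# ...or pass a float and let the importance estimator pick the channels to remove.
pruner.manual_prune(model.layer1[0].conv1, tp.prune_conv_out_channels, 0.5)
```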
10 changes: 3 additions & 7 deletions torch_pruning/pruner/function.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,10 @@
import torch
import torch.nn as nn

from .. import ops

from copy import deepcopy
from functools import reduce
from operator import mul
from abc import ABC, abstractclassmethod
from typing import Sequence, Tuple

from abc import ABC, abstractclassmethod, abstractmethod, abstractstaticmethod
from typing import Callable, Sequence, Tuple, Dict
from .. import ops

__all__=[
'BasePruningFunc',