
Merge pull request #403 from VainF/v2.0
v1.4.1
VainF authored Jul 21, 2024
2 parents 2211538 + 2bab065 commit d80bcc2
Showing 6 changed files with 225 additions and 217 deletions.
21 changes: 18 additions & 3 deletions README.md
@@ -23,7 +23,7 @@

Torch-Pruning (TP) is designed for structural pruning, facilitating the following features:

* **General-purpose Pruning Toolkit:** TP enables structural pruning for a wide range of deep neural networks, including [Large Language Models (LLMs)](https://github.com/VainF/Torch-Pruning/tree/master/examples/LLMs), [Segment Anything Model (SAM)](https://github.com/czg1225/SlimSAM), [Diffusion Models](https://github.com/VainF/Diff-Pruning), [Yolov7](examples/yolov7/), [yolov8](examples/yolov8/), [Vision Transformers](examples/transformers/), [Swin Transformers](examples/transformers#swin-transformers-from-hf-transformers), [BERT](examples/transformers#bert-from-hf-transformers), FasterRCNN, SSD, ResNe(X)t, ConvNext, DenseNet, ConvNext, RegNet, DeepLab, etc. Different from [torch.nn.utils.prune](https://pytorch.org/tutorials/intermediate/pruning_tutorial.html) that zeroizes parameters through masking, Torch-Pruning deploys an algorithm called **[DepGraph](https://openaccess.thecvf.com/content/CVPR2023/html/Fang_DepGraph_Towards_Any_Structural_Pruning_CVPR_2023_paper.html)** to remove parameters physically.
* **General-purpose Pruning Toolkit:** TP enables structural pruning for a wide range of deep neural networks, including [Large Language Models (LLMs)](https://github.com/VainF/Torch-Pruning/tree/master/examples/LLMs), [Segment Anything Model (SAM)](https://github.com/czg1225/SlimSAM), [Diffusion Models](https://github.com/VainF/Diff-Pruning), [Vision Transformers](https://github.com/VainF/Isomorphic-Pruning), [ConvNext](https://github.com/VainF/Isomorphic-Pruning), [Yolov7](examples/yolov7/), [yolov8](examples/yolov8/), [Swin Transformers](examples/transformers#swin-transformers-from-hf-transformers), [BERT](examples/transformers#bert-from-hf-transformers), FasterRCNN, SSD, ResNe(X)t, DenseNet, RegNet, DeepLab, etc. Different from [torch.nn.utils.prune](https://pytorch.org/tutorials/intermediate/pruning_tutorial.html) that zeroizes parameters through masking, Torch-Pruning deploys an algorithm called **[DepGraph](https://openaccess.thecvf.com/content/CVPR2023/html/Fang_DepGraph_Towards_Any_Structural_Pruning_CVPR_2023_paper.html)** to remove parameters physically.
* [Examples](examples): Pruning off-the-shelf models from Timm, Huggingface Transformers, Torchvision, Yolo, etc.
* [Code for reproducing paper results](reproduce): Reproduce our results from the DepGraph paper.

@@ -206,15 +206,30 @@ print(f"MACs: {base_macs/1e9} G -> {macs/1e9} G, #Params: {base_nparams/1e6} M -
MACs: 1.822177768 G -> 0.487202536 G, #Params: 11.689512 M -> 3.05588 M
```
#### Global Pruning
Global pruning performs importance ranking across all layers, which has the potential to find better structures. This can be easily achieved by setting ``global_pruning=True`` in the pruner. While this strategy can offer performance advantages, it also risks over-pruning specific layers, resulting in a substantial decline in overall performance. We provide an alternative algorithm called [Isomorphic Pruning](https://arxiv.org/abs/2407.04616) to alleviate this issue, which can be enabled with ``isomorphic=True``. For more details, please see our official [codebase for ViTs and CNNs](https://github.com/VainF/Isomorphic-Pruning).
Global pruning performs importance ranking across all layers, which has the potential to find better structures. This can be easily achieved by setting ``global_pruning=True`` in the pruner. While this strategy can offer performance advantages, it also risks over-pruning specific layers, resulting in a substantial decline in overall performance. We provide an alternative algorithm called [Isomorphic Pruning](https://arxiv.org/abs/2407.04616) to alleviate this issue, which can be enabled with ``isomorphic=True``.
```python
pruner = tp.pruner.MetaPruner(
...
ismorphic=True, # enable isomorphic pruning to improve global ranking
isomorphic=True, # enable isomorphic pruning to improve global ranking
global_pruning=True, # global pruning
)
```

<div align="center">
<img src="assets/isomorphic_pruning.png" width="96%">
</div>

#### Pruning Ratios

The default pruning ratio can be set by ``pruning_ratio``. If you want to customize the pruning ratio for some layers or blocks, you can use ``pruning_ratio_dict``. The key of the dict can be an ``nn.Module`` or a tuple of ``nn.Module``s. In the latter case, all modules in the tuple form a ``scope`` and share the same pruning ratio, with global ranking performed within this scope. This is also the core idea of [Isomorphic Pruning](https://arxiv.org/abs/2407.04616).
```python
pruner = tp.pruner.MetaPruner(
...
pruning_ratio=0.5, # default pruning ratio
pruning_ratio_dict = {(model.layer1, model.layer2): 0.4, model.layer3: 0.2},
# Global pruning will be performed on layer1 and layer2
)
```
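
As a quick sanity check (a rough sketch, not part of the original README; ``model`` is assumed to be a torchvision ResNet-18-style network and the exact widths depend on the importance scores), you can execute the pruner and print the remaining channel count of each convolution:

```python
pruner.step()  # execute the pruning plan configured above
for name, module in model.named_modules():
    if isinstance(module, torch.nn.Conv2d):
        # layers inside the (layer1, layer2) scope are ranked jointly,
        # so their remaining widths can differ even with a shared 0.4 ratio
        print(name, module.out_channels)
```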

#### Sparse Training (Optional)
Some pruners like [BNScalePruner](https://github.com/VainF/Torch-Pruning/blob/dd59921365d72acb2857d3d74f75c03e477060fb/torch_pruning/pruner/algorithms/batchnorm_scale_pruner.py#L45) and [GroupNormPruner](https://github.com/VainF/Torch-Pruning/blob/dd59921365d72acb2857d3d74f75c03e477060fb/torch_pruning/pruner/algorithms/group_norm_pruner.py#L53) support sparse training. This can be easily achieved by inserting ``pruner.update_regularizer()`` and ``pruner.regularize(model)`` into your standard training loop. The pruner will accumulate the regularization gradients into ``.grad``. Sparse training is optional and can add noticeable cost to the pruning pipeline.
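
As a rough sketch (the loop structure, ``optimizer``, ``train_loader``, ``epochs``, and the loss are assumptions, not part of Torch-Pruning), sparse training typically interleaves the two calls with a standard training step:

```python
import torch.nn.functional as F

for epoch in range(epochs):
    model.train()
    pruner.update_regularizer()       # (re)build the regularizer for this epoch
    for x, y in train_loader:
        optimizer.zero_grad()
        loss = F.cross_entropy(model(x), y)
        loss.backward()               # task gradients first
        pruner.regularize(model)      # then accumulate sparsity gradients into .grad
        optimizer.step()
```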
Binary file added assets/isomorphic_pruning.png
7 changes: 4 additions & 3 deletions examples/timm_models/prune_timm_models.py
@@ -39,9 +39,10 @@ def main():
num_heads = {}

for m in model.modules():
if hasattr(m, 'head'): #isinstance(m, nn.Linear) and m.out_features == model.num_classes:
ignored_layers.append(model.head)
print("Ignore classifier layer: ", m.head)
#if hasattr(m, 'head'): #isinstance(m, nn.Linear) and m.out_features == model.num_classes:
if isinstance(m, nn.Linear) and m.out_features == model.num_classes:
ignored_layers.append(m)
print("Ignore classifier layer: ", m)

# Attention layers
if hasattr(m, 'num_heads'):
3 changes: 3 additions & 0 deletions examples/transformers/readme.md
@@ -1,5 +1,8 @@
# Transformers

This example demonstrates the minimal code needed to prune Transformers, including Vision Transformers (ViT), Swin Transformers, and BERT. If you need a more comprehensive example of pruning and finetuning, please refer to the [codebase for Isomorphic Pruning](https://github.com/VainF/Isomorphic-Pruning), where detailed instructions and pre-pruned models are available.


## Pruning ViT-ImageNet-21K-ft-1K from [Timm](https://github.com/huggingface/pytorch-image-models)

### Data
8 changes: 5 additions & 3 deletions tests/test_taylor_importance.py
@@ -14,16 +14,18 @@ def test_taylor():
if isinstance(m, torch.nn.Linear) and m.out_features == 1000:
ignored_layers.append(m) # DO NOT prune the final classifier!

iterative_steps = 5 # progressive pruning
iterative_steps = 1 # number of pruning iterations (set >1 for progressive pruning)
pruner = tp.pruner.MagnitudePruner(
model,
example_inputs,
importance=imp,
iterative_steps=iterative_steps,
pruning_ratio=0.5, # remove 50% channels, ResNet18 = {64, 128, 256, 512} => ResNet18_Half = {32, 64, 128, 256}
global_pruning=True,
pruning_ratio=0.1, # default pruning ratio; layer1-3 are overridden by pruning_ratio_dict below
pruning_ratio_dict={model.layer1: 0.5, (model.layer2, model.layer3): 0.5},
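# note: (model.layer2, model.layer3) forms a single scope that shares one importance ranking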
ignored_layers=ignored_layers,
)

print(model)
base_macs, base_nparams = tp.utils.count_ops_and_params(model, example_inputs)
for i in range(iterative_steps):
if isinstance(imp, tp.importance.TaylorImportance):
