Skip to content

Commit

Permalink
Fixed some issues in GQA pruning
Browse files Browse the repository at this point in the history
  • Loading branch information
VainF committed Oct 7, 2024
2 parents 59a26a7 + 3e98708 commit b08aa0e
Show file tree
Hide file tree
Showing 5 changed files with 18 additions and 16 deletions.
18 changes: 11 additions & 7 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@

<p align="center">
<a href="https://github.com/VainF/Torch-Pruning/actions"><img src="https://img.shields.io/badge/tests-passing-9c27b0.svg" alt="Test Status"></a>
<a href="https://pytorch.org/"><img src="https://img.shields.io/badge/PyTorch-1.12 %20%7C%202.0-673ab7.svg" alt="Tested PyTorch Versions"></a>
<a href="https://pytorch.org/"><img src="https://img.shields.io/badge/PyTorch-1.x %20%7C%202.x-673ab7.svg" alt="Tested PyTorch Versions"></a>
<a href="https://opensource.org/licenses/MIT"><img src="https://img.shields.io/badge/License-MIT-4caf50.svg" alt="License"></a>
<a href="https://pepy.tech/project/Torch-Pruning"><img src="https://static.pepy.tech/badge/Torch-Pruning?color=2196f3" alt="Downloads"></a>
<a href="https://github.com/VainF/Torch-Pruning/releases/latest"><img src="https://img.shields.io/badge/Latest%20Version-1.4.1-3f51b5.svg" alt="Latest Version"></a>
<a href="https://github.com/VainF/Torch-Pruning/releases/latest"><img src="https://img.shields.io/badge/Latest%20Version-1.4.2-3f51b5.svg" alt="Latest Version"></a>
<a href="https://colab.research.google.com/drive/1TRvELQDNj9PwM-EERWbF3IQOyxZeDepp?usp=sharing">
<img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>
Expand All @@ -33,7 +33,8 @@ For more technical details, please refer to our CVPR'23 paper:
> *[Learning and Vision Lab](http://lv-nus.org/), National University of Singapore*
### Update:
- :rocket: 2024.07.20 Add [**Isomorphic Pruning**](https://arxiv.org/abs/2407.04616). A SOTA method for Vision Transformers and Modern CNNs.
- 🔥 2024.09.27 Check our latest work, [**MaskLLM (NeurIPS 24 Spotlight)**](https://github.com/NVlabs/MaskLLM), for learnable semi-structured sparsity of LLMs.
- 🚀 2024.07.20 Add [**Isomorphic Pruning (ECCV'24)**](https://arxiv.org/abs/2407.04616). A SOTA method for Vision Transformers and Modern CNNs.

### **Features:**
- :zap: High-level Pruners: [MetaPruner](torch_pruning/pruner/algorithms/metapruner.py), [MagnitudePruner](https://arxiv.org/abs/1608.08710), [BNScalePruner](https://arxiv.org/abs/1708.06519), [GroupNormPruner](https://arxiv.org/abs/2301.12900), [GrowingRegPruner](https://arxiv.org/abs/2012.09243), RandomPruner, etc. A paper list is available [here](https://github.com/VainF/Torch-Pruning/wiki/0.-Paper-List).
Expand All @@ -46,9 +47,8 @@ For more technical details, please refer to our CVPR'23 paper:

### **Contact Us:**
Please do not hesitate to open an [issue](https://github.com/VainF/Torch-Pruning/issues) if you encounter any problems with the library or the paper.
Or Join our Discord or WeChat group for a chat:
* Discord: [link](https://discord.gg/Pvd6hbYXRs)
* WeChat Group [Group-2](https://github.com/user-attachments/assets/4072cc1e-63d7-4f33-b003-1e8da516f421), [Group-1 (500/500, FULL)](https://github.com/VainF/Torch-Pruning/assets/18592211/35d66130-eb03-4dcb-ad75-8df784460ad3).
Or Join our WeChat group for a chat:
* WeChat Group [Group-2](https://github.com/user-attachments/assets/3fe4c487-5a5b-43fd-bf64-a5ee62c3dec1) (>200/500), [Group-1](https://github.com/VainF/Torch-Pruning/assets/18592211/35d66130-eb03-4dcb-ad75-8df784460ad3) (500/500, FULL).

## Table of Contents
- [Installation](#installation)
Expand Down Expand Up @@ -204,6 +204,10 @@ print(f"MACs: {base_macs/1e9} G -> {macs/1e9} G, #Params: {base_nparams/1e6} M -
# ...
```
```
# Note: In TP, pruning ratio means channel pruning ratio.
# Since both in & out channels will be removed by p%,
# the corresponding parameter pruning ratio will be roughly 1-(1-p%)^2.
# In this example, 3.06 ~= 11.69 * (1-0.5)^2 = 2.92
MACs: 1.822177768 G -> 0.487202536 G, #Params: 11.689512 M -> 3.05588 M
```
#### Global Pruning and Isomorphic Pruning
Expand Down Expand Up @@ -435,7 +439,7 @@ Latency test on ResNet-50, Batch Size=64.
> *Xinyin Ma, Gongfan Fang, and Xinchao Wang*
> CVPR 2024
> **0.1% Data Makes Segment Anything Slim** [[Project]](https://github.com/czg1225/SlimSAM) [[Arxiv]](https://arxiv.org/abs/2312.05284)
> **SlimSAM: 0.1% Data Makes Segment Anything Slim** [[Project]](https://github.com/czg1225/SlimSAM) [[Arxiv]](https://arxiv.org/abs/2312.05284)
> *Zigeng Chen, Gongfan Fang, Xinyin Ma, Xinchao Wang*
> Preprint 2023
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
torchrun --nproc_per_node=8 finetune.py \
--model "output/pruned/vit_base_patch16_224_pruned_l1_uniform.pth" \
--epochs 3000 \
--epochs 300 \
--batch-size 256 \
--opt adamw \
--lr 0.00015 \
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

setuptools.setup(
name="torch-pruning",
version="v1.4.1",
version="v1.4.2",
author="Gongfan Fang",
author_email="[email protected]",
description="Towards Any Structural Pruning",
Expand Down
7 changes: 2 additions & 5 deletions torch_pruning/pruner/algorithms/metapruner.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,8 +443,7 @@ def _prune(self) -> typing.Generator:
head_imp = imp.view(num_heads, -1).mean(1).cpu() # average importance by head.
ranking_scope[ATTN_HEAD_SCOPE][group] = (qkv_layers, head_imp)


# Scope 1: User-defined pruning ratios
# Scope 1: User-defined scope, such as layer-wise pruning_ratios
is_user_defined_scope = False
for dep, _ in group:
for module, pruning_fn in zip([dep.source.module, dep.target.module], [dep.trigger, dep.handler]):
Expand All @@ -463,9 +462,7 @@ def _prune(self) -> typing.Generator:
if is_user_defined_scope:
continue

# otherwise, use the default pruning ratio
record = (group, ch_groups, group_size, self.per_step_pruning_ratio[self.current_step], dim_imp)

record = (group, ch_groups, group_size, self.per_step_pruning_ratio[self.current_step], dim_imp) # otherwise, use the default pruning ratio
# Scope 2: Isomorphic Pruning
if self.isomorphic:
scope_name = "Isomorphic_" # we transform the graph structure into a string tag for easy comparison
Expand Down
5 changes: 3 additions & 2 deletions torch_pruning/pruner/algorithms/scheduler.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from typing import List

def linear_scheduler(pruning_ratio_dict, steps):
return [((i) / float(steps)) * pruning_ratio_dict for i in range(steps+1)]
def linear_scheduler(pruning_ratio: float, steps: int) -> List[float]:
return [((i) / float(steps)) * pruning_ratio for i in range(steps + 1)]

0 comments on commit b08aa0e

Please sign in to comment.