Skip to content

Commit

Permalink
Merge pull request #234 from VainF/v1.2
Browse files Browse the repository at this point in the history
V1.2.2
  • Loading branch information
VainF authored Aug 14, 2023
2 parents 6051fa5 + 6dbc22d commit 2cc7dcf
Show file tree
Hide file tree
Showing 17 changed files with 325 additions and 330 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test_torch_181.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,4 +27,4 @@ jobs:
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
- name: Test with pytest
run: |
pytest --ignore=tests/test_unwrapped_parameters.py --ignore=tests/test_backward.py --ignore=tests/test_concat_split.py --ignore=tests/test_serialization.py
pytest --ignore=tests/test_unwrapped_parameters.py --ignore=tests/test_backward.py --ignore=tests/test_concat_split.py --ignore=tests/test_serialization.py --ignore=tests/test_non_feature_dim_cat.py
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
<a href="https://pytorch.org/"><img src="https://img.shields.io/badge/PyTorch-1.8 %20%7C%201.12 %20%7C%202.0-673ab7.svg" alt="Tested PyTorch Versions"></a>
<a href="https://opensource.org/licenses/MIT"><img src="https://img.shields.io/badge/License-MIT-4caf50.svg" alt="License"></a>
<a href="https://pepy.tech/project/Torch-Pruning"><img src="https://pepy.tech/badge/Torch-Pruning?color=2196f3" alt="Downloads"></a>
<a href="https://github.com/VainF/Torch-Pruning/releases/latest"><img src="https://img.shields.io/badge/Latest%20Version-1.2.1-3f51b5.svg" alt="Latest Version"></a>
<a href="https://github.com/VainF/Torch-Pruning/releases/latest"><img src="https://img.shields.io/badge/Latest%20Version-1.2.2-3f51b5.svg" alt="Latest Version"></a>
<a href="https://colab.research.google.com/drive/1TRvELQDNj9PwM-EERWbF3IQOyxZeDepp?usp=sharing">
<img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>
Expand Down
2 changes: 1 addition & 1 deletion benchmarks/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ def get_pruner(model, example_inputs):
pruner_entry = partial(tp.pruner.GroupNormPruner, global_pruning=args.global_pruning)
elif args.method == "group_sl":
args.sparsity_learning = True
imp = tp.importance.GroupNormImportance(p=2)
imp = tp.importance.GroupNormImportance(p=2, normalizer='max') # normalized by the maximum score for CIFAR
pruner_entry = partial(tp.pruner.GroupNormPruner, reg=args.reg, global_pruning=args.global_pruning)
elif args.method == "growing_reg":
args.sparsity_learning = True
Expand Down
52 changes: 52 additions & 0 deletions examples/hf_transformers/prune_bert.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
from transformers import AutoTokenizer, BertModel
import torch
from transformers.models.bert.modeling_bert import BertSelfAttention
import torch_pruning as tp

# Example: structurally prune a Hugging Face BERT model with Torch-Pruning.
# The script loads "bert-base-uncased", prunes it with a magnitude criterion,
# patches the attention modules' size attributes, and prints MACs/params
# measured before and after pruning.

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = BertModel.from_pretrained("bert-base-uncased")
hf_inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
# Keyword inputs that are fed to the model as `model(**example_inputs)`.
example_inputs = {
    'input_ids': hf_inputs['input_ids'],
    'token_type_ids': hf_inputs['token_type_ids'],
    'attention_mask': hf_inputs['attention_mask'],
}

imp = tp.importance.MagnitudeImportance(p=2, group_reduction="mean")
# Baseline cost, measured before any layer is modified.
base_macs, base_params = tp.utils.count_ops_and_params(model, example_inputs)
channel_groups = {}

# All heads should be pruned simultaneously, so we group channels by head.
for m in model.modules():
    if isinstance(m, BertSelfAttention):
        channel_groups[m.query] = m.num_attention_heads
        channel_groups[m.key] = m.num_attention_heads
        channel_groups[m.value] = m.num_attention_heads

pruner = tp.pruner.MagnitudePruner(
    model,
    example_inputs,
    global_pruning=False,  # If False, a uniform sparsity will be assigned to different layers.
    importance=imp,  # importance criterion for parameter selection
    iterative_steps=1,  # the number of iterations to achieve target sparsity
    ch_sparsity=0.5,
    channel_groups=channel_groups,
    output_transform=lambda out: out.pooler_output.sum(),  # reduce the model output to a scalar
    ignored_layers=[model.pooler],
)

# Interactive mode yields the pruning groups, so each one can be printed
# before it is applied with `g.prune()`.
for g in pruner.step(interactive=True):
    print(g)
    g.prune()

# Modify the attention head size and the all-head size after pruning, so the
# attributes agree with the new width of the query projection.
for m in model.modules():
    if isinstance(m, BertSelfAttention):
        m.attention_head_size = m.query.out_features // m.num_attention_heads
        m.all_head_size = m.query.out_features

print(model)
# Forward pass on the pruned model; the result itself is not inspected.
test_output = model(**example_inputs)
pruned_macs, pruned_params = tp.utils.count_ops_and_params(model, example_inputs)
print("Base MACs: %f M, Pruned MACs: %f M"%(base_macs/1e6, pruned_macs/1e6))
print("Base Params: %f M, Pruned Params: %f M"%(base_params/1e6, pruned_params/1e6))
3 changes: 3 additions & 0 deletions examples/timm_models/timm_pruning.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@
problem_with_input_shape = []
for i, model_name in enumerate(timm_models):
print("Pruning %s..."%model_name)
if "botnet" in model_name or "coatnet" in model_name or "coatnext" in model_name:
unprunable_list.append(model_name)
continue
device = 'cuda' if torch.cuda.is_available() else 'cpu'
#if 'rexnet' in model_name or 'sequencer' in model_name or 'botnet' in model_name: # the pruning process gets stuck on these architectures - skip them.
# unprunable_list.append(model_name)
Expand Down
11 changes: 9 additions & 2 deletions examples/torchvision_models/torchvision_pruning.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,7 @@ def my_prune(model, example_inputs, output_transform, model_name):
importance=importance,
iterative_steps=1,
ch_sparsity=0.5,
global_pruning=False,
round_to=round_to,
unwrapped_parameters=unwrapped_parameters,
ignored_layers=ignored_layers,
Expand Down Expand Up @@ -256,7 +257,7 @@ def my_prune(model, example_inputs, output_transform, model_name):
successful = []
unsuccessful = []
for model_name, entry in entries.items():
if 'swin' in model_name.lower(): # stuck
if 'swin' in model_name.lower() or 'raft' in model_name.lower() or 'shufflenet' in model_name.lower(): # stuck
unsuccessful.append(model_name)
continue

Expand Down Expand Up @@ -303,4 +304,10 @@ def my_prune(model, example_inputs, output_transform, model_name):
print("Successful Pruning: %d Models\n"%(len(successful)), successful)
print("")
print("Unsuccessful Pruning: %d Models\n"%(len(unsuccessful)), unsuccessful)
sys.stdout.flush()
sys.stdout.flush()

print("Finished!")

print("Successful Pruning: %d Models\n"%(len(successful)), successful)
print("")
print("Unsuccessful Pruning: %d Models\n"%(len(unsuccessful)), unsuccessful)
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

setuptools.setup(
name="torch-pruning",
version="v1.2.1",
version="v1.2.2",
author="Gongfan Fang",
author_email="[email protected]",
description="Towards Any Structural Pruning",
Expand Down
82 changes: 82 additions & 0 deletions tests/test_non_feature_dim_cat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
import sys, os

sys.path.append(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))

import torch
import torch_pruning as tp
import torch.nn as nn

class Net(nn.Module):
    """Toy network that concatenates and splits along a non-channel dimension.

    Two identical branches process the input in parallel. Their outputs are
    concatenated along ``dim=2`` (not the channel dimension ``dim=1``), passed
    through a 1x1 convolution, split back along ``dim=2`` and only the first
    chunk is forwarded to the final 1x1 convolution.

    Args:
        in_dim: number of input channels; every layer keeps this width.
    """

    def __init__(self, in_dim):
        super().__init__()
        # Attribute names and construction order are kept so that parameter
        # names (e.g. ``block1.0.weight``) stay the same.
        self.block1 = self._make_branch(in_dim)
        self.parallel_path = self._make_branch(in_dim)

        self.conv1 = nn.Conv2d(in_dim, in_dim, 1)
        self.conv2 = nn.Conv2d(in_dim, in_dim, 1)

    @staticmethod
    def _make_branch(in_dim):
        """Build one Conv-BN-GELU-Conv-BN branch of 1x1 convolutions."""
        return nn.Sequential(
            nn.Conv2d(in_dim, in_dim, 1),
            nn.BatchNorm2d(in_dim),
            nn.GELU(),
            nn.Conv2d(in_dim, in_dim, 1),
            nn.BatchNorm2d(in_dim),
        )

    def forward(self, x):
        """Return ``conv2`` applied to the first ``dim=2`` chunk of the fused features."""
        x1 = self.block1(x)
        x2 = self.parallel_path(x)
        x = torch.cat([x1, x2], dim=2)
        x = self.conv1(x)
        # Split with the sizes the two branches contributed; the second chunk
        # is intentionally discarded.
        x1, _ = torch.split(x, [x1.shape[2], x2.shape[2]], dim=2)
        return self.conv2(x1)

def test_pruner():
    """Smoke-test pruning a model that cats/splits along a non-feature dim.

    Builds ``Net(512)``, prunes it with ``MagnitudePruner`` at 50% channel
    sparsity, runs a forward pass on the pruned model and checks that the
    parameter count went down.
    """
    model = Net(512)
    print(model)
    example_inputs = torch.randn(1, 512, 7, 7)
    imp = tp.importance.MagnitudeImportance(p=2)

    iterative_steps = 1
    pruner = tp.pruner.MagnitudePruner(
        model,
        example_inputs,
        importance=imp,
        iterative_steps=iterative_steps,
        ch_sparsity=0.5,  # target: remove 50% of the channels
        ignored_layers=[],  # Net has no classifier head that must be protected
    )

    base_macs, base_nparams = tp.utils.count_ops_and_params(model, example_inputs)
    for i in range(iterative_steps):
        pruner.step()
        print(model)
        macs, nparams = tp.utils.count_ops_and_params(model, example_inputs)

        # The pruned model must still run end to end.
        print(model(example_inputs).shape)
        print(
            "  Iter %d/%d, Params: %.2f => %.2f"
            % (i+1, iterative_steps, base_nparams, nparams)
        )
        print(
            "  Iter %d/%d, MACs: %.2f => %.2f"
            % (i+1, iterative_steps, base_macs, macs)
        )
        # Pruning must actually have removed parameters.
        assert nparams < base_nparams
        # finetune your model here
        # finetune(model)
        # ...

# Allow running this test file directly as a script.
if __name__ == '__main__':
    test_pruner()
54 changes: 39 additions & 15 deletions torch_pruning/dependency.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
__all__ = ["Dependency", "Group", "DependencyGraph"]

_PLACEHOLDER = None
MAX_RECURSION_DEPTH = 100

def equal_func(func1, func2):
return (
Expand Down Expand Up @@ -116,7 +117,7 @@ def __init__(
self.target = target
# Current coordinate system => Standard coordinate system => target coordinate system
# index_mapping[0] index_mapping[1]
self.index_mapping = [None, None]
self.index_mapping = [_PLACEHOLDER, _PLACEHOLDER] # [None, None] by default

def __call__(self, idxs: list):
self.handler.__self__.pruning_dim = self.target.pruning_dim # set pruning_dim
Expand Down Expand Up @@ -568,15 +569,18 @@ def get_in_channels(self, module_or_node):
return None
return p.get_in_channels(module)

def _infer_out_channels_recursively(self, node: Node):
def _infer_out_channels_recursively(self, node: Node, recursive_depth: list):
""" infer the number of output channels recursively
"""
if recursive_depth[0] > MAX_RECURSION_DEPTH:
return None
ch = self.get_out_channels(node)
if ch is None:
ch = 0
for in_node in node.inputs:
if node.type == ops.OPTYPE.CONCAT:
sub_ch = self._infer_out_channels_recursively(in_node)
recursive_depth[0]+=1
sub_ch = self._infer_out_channels_recursively(in_node, recursive_depth)
if sub_ch is None:
return None
ch += sub_ch
Expand All @@ -586,25 +590,30 @@ def _infer_out_channels_recursively(self, node: Node):
if split_out_node == node:
ch = in_node.module.split_sizes[i]
else:
ch = self._infer_out_channels_recursively(in_node)
recursive_depth[0]+=1
ch = self._infer_out_channels_recursively(in_node, recursive_depth)
if ch == 0:
return None
return ch

def _infer_in_channels_recursively(self, node: Node):
def _infer_in_channels_recursively(self, node: Node, recursive_depth: list):
""" infer the number of input channels recursively
"""
if recursive_depth[0] > MAX_RECURSION_DEPTH:
return None
ch = self.get_in_channels(node)
if ch is None:
ch = 0
for out_node in node.outputs:
if node.type == ops.OPTYPE.SPLIT:
sub_ch = self._infer_in_channels_recursively(out_node)
recursive_depth[0]+=1
sub_ch = self._infer_in_channels_recursively(out_node, recursive_depth)
if sub_ch is None:
return None
ch += sub_ch
else:
ch = self._infer_in_channels_recursively(out_node)
recursive_depth[0]+=1
ch = self._infer_in_channels_recursively(out_node, recursive_depth)
if ch == 0:
return None
return ch
Expand Down Expand Up @@ -707,7 +716,7 @@ def _record_grad_fn(module, inputs, outputs):
for m in model.modules()
if (isinstance(m, registered_types) and m not in self.IGNORED_LAYERS)
]

# Feed forward to record gradient functions of prunable modules
if forward_fn is not None:
out = forward_fn(model, example_inputs)
Expand Down Expand Up @@ -875,7 +884,8 @@ def _init_shape_information(self):
else: # legacy version
chs = []
for n in node.outputs:
chs.append(self._infer_in_channels_recursively(n))
recursive_depth = [0]
chs.append(self._infer_in_channels_recursively(n, recursive_depth))
offsets = [0]
for ch in chs:
if ch is None: continue
Expand All @@ -889,7 +899,8 @@ def _update_flatten_index_mapping(self, fc_node: Node):
fc_in_features = fc_node.module.in_features
feature_channels = 0
for n in fc_node.inputs:
feature_channels = self._infer_out_channels_recursively(n)
recursive_depth = [0]
feature_channels = self._infer_out_channels_recursively(n, recursive_depth)
if feature_channels is not None: # =0 if there is a residual connection to model inputs
break
if (
Expand Down Expand Up @@ -924,13 +935,19 @@ def _update_reshape_index_mapping(self, reshape_node: Node):

out_channels = None
for n in reshape_node.outputs:
out_channels = self._infer_in_channels_recursively(n)
recursive_depth = [0]
out_channels = self._infer_in_channels_recursively(n, recursive_depth)
if recursive_depth[0] > MAX_RECURSION_DEPTH:
return
if out_channels is not None: # =0 if there is a residual connection to model inputs
break

in_channels = None
for n in reshape_node.inputs:
in_channels = self._infer_out_channels_recursively(n)
recursive_depth = [0]
in_channels = self._infer_out_channels_recursively(n, recursive_depth)
if recursive_depth[0] > MAX_RECURSION_DEPTH:
return
if in_channels is not None: # =0 if there is a residual connection to model inputs
break

Expand Down Expand Up @@ -976,7 +993,10 @@ def _update_reshape_index_mapping(self, reshape_node: Node):
def _update_concat_index_mapping(self, cat_node: Node):
if cat_node.type != ops.OPTYPE.CONCAT:
return


if hasattr(cat_node.grad_fn, '_saved_dim') and cat_node.grad_fn._saved_dim != 1: # this only works for Pytorch>=1.12
return

if cat_node.module.concat_sizes is not None:
chs = cat_node.module.concat_sizes
else:
Expand Down Expand Up @@ -1024,7 +1044,10 @@ def _update_concat_index_mapping(self, cat_node: Node):
def _update_split_index_mapping(self, split_node: Node):
if split_node.type != ops.OPTYPE.SPLIT:
return


if hasattr(split_node.grad_fn, '_saved_dim') and split_node.grad_fn._saved_dim != 1: # this only works for Pytorch>=1.12
return

offsets = split_node.module.offsets
if offsets is None:
return
Expand Down Expand Up @@ -1057,7 +1080,8 @@ def infer_channels_between(self, node_1, node_2):
for i, n in enumerate(node_1.outputs):
if n == node_2:
return node_1.module.split_sizes[i]
return self._infer_out_channels_recursively(node_1)
recursive_depth = [0]
return self._infer_out_channels_recursively(node_1, recursive_depth)



Expand Down
4 changes: 4 additions & 0 deletions torch_pruning/pruner/algorithms/batchnorm_scale_pruner.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@
from ..importance import MagnitudeImportance

class BNScalePruner(MetaPruner):
"""Learning Efficient Convolutional Networks through Network Slimming,
https://arxiv.org/abs/1708.06519
"""

def __init__(
self,
model,
Expand Down
Loading

0 comments on commit 2cc7dcf

Please sign in to comment.