An example for Timm ViT
VainF committed Jul 26, 2023
1 parent 71028c1 commit a8027de
Showing 3 changed files with 64 additions and 4 deletions.
57 changes: 57 additions & 0 deletions examples/hf_transformers/prune_vit.py
@@ -0,0 +1,57 @@
from transformers import ViTImageProcessor, ViTForImageClassification
from transformers.models.vit.modeling_vit import ViTSelfAttention
import torch_pruning as tp
from PIL import Image
import requests

url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
image = Image.open(requests.get(url, stream=True).raw)

processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224')
model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224')
example_inputs = processor(images=image, return_tensors="pt")["pixel_values"]
#outputs = model(example_inputs)
#logits = outputs.logits
# model predicts one of the 1000 ImageNet classes
#predicted_class_idx = logits.argmax(-1).item()
#print("Predicted class:", model.config.id2label[predicted_class_idx])

print(model)
imp = tp.importance.MagnitudeImportance(p=2, group_reduction="mean")
base_macs, base_params = tp.utils.count_ops_and_params(model, example_inputs)
channel_groups = {}

# All heads should be pruned simultaneously, so we group channels by head.
for m in model.modules():
    if isinstance(m, ViTSelfAttention):
        channel_groups[m.query] = m.num_attention_heads
        channel_groups[m.key] = m.num_attention_heads
        channel_groups[m.value] = m.num_attention_heads

pruner = tp.pruner.MagnitudePruner(
    model,
    example_inputs,
    global_pruning=False, # If False, a uniform sparsity will be assigned to different layers.
    importance=imp, # importance criterion for parameter selection
    iterative_steps=1, # the number of iterations to achieve target sparsity
    ch_sparsity=0.5,
    channel_groups=channel_groups,
    output_transform=lambda out: out.logits.sum(),
    ignored_layers=[model.classifier],
)

for g in pruner.step(interactive=True):
    #print(g)
    g.prune()

# Update the attention head size and all-head size after pruning
for m in model.modules():
    if isinstance(m, ViTSelfAttention):
        m.attention_head_size = m.query.out_features // m.num_attention_heads
        m.all_head_size = m.query.out_features

print(model)
test_output = model(example_inputs)
pruned_macs, pruned_params = tp.utils.count_ops_and_params(model, example_inputs)
print("Base MACs: %d G, Pruned MACs: %d G"%(base_macs/1e9, pruned_macs/1e9))
print("Base Params: %d M, Pruned Params: %d M"%(base_params/1e6, pruned_params/1e6))
6 changes: 6 additions & 0 deletions examples/hf_transformers/readme.md
@@ -0,0 +1,6 @@
# Example for HuggingFace ViT

## Pruning
```bash
python prune_vit.py
```
5 changes: 1 addition & 4 deletions torch_pruning/dependency.py
@@ -634,7 +634,7 @@ def _detect_unwrapped_parameters(self, unwrapped_parameters):
            unwrapped_parameters = []
        unwrapped_detected = list( set(unwrapped_detected) - set([p for (p, _) in unwrapped_parameters]) )
        if len(unwrapped_detected)>0 and self.verbose:
-           warning_str = "Unwrapped parameters detected: {}.\n Torch-Pruning will prune the last non-singleton dimension of a parameter. If you wish to customize this behavior, please provide an unwrapped_parameters argument.".format([_param_to_name[p] for p in unwrapped_detected])
+           warning_str = "Unwrapped parameters detected: {}.\n Torch-Pruning will prune the last non-singleton dimension of these parameters. If you wish to change this behavior, please provide an unwrapped_parameters argument.".format([_param_to_name[p] for p in unwrapped_detected])
            warnings.warn(warning_str)

        # set default pruning dim for unwrapped parameters
@@ -782,9 +782,6 @@ def create_node_if_not_exists(grad_fn):
                module = ops._SplitOp(self._op_id)
                self._op_id+=1
            elif "view" in grad_fn.name().lower() or 'reshape' in grad_fn.name().lower():
-               #if 'reshape' in grad_fn.name().lower():
-                   #print(grad_fn.__dir__())
-                   #print(grad_fn._saved_self_sizes)
                module = ops._ReshapeOp(self._op_id)
                self._op_id+=1
            else:
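
The warning updated in the first hunk above points to the `unwrapped_parameters` argument, which covers parameters not owned by any prunable module (for this ViT example, the CLS token and position embeddings). Below is a hedged sketch of how that argument could be passed in the example script; the entries are `(parameter, pruning_dim)` pairs, matching the `for (p, _) in unwrapped_parameters` unpacking above. The attribute paths `model.vit.embeddings.cls_token` and `model.vit.embeddings.position_embeddings` are assumptions about the HuggingFace ViT layout, and `model`, `example_inputs`, `imp`, and `channel_groups` are reused from prune_vit.py.

```python
# Sketch (not part of this commit): explicitly registering free nn.Parameters
# and the dimension to prune, instead of relying on the default
# "last non-singleton dimension" behavior described in the warning.
unwrapped_parameters = [
    (model.vit.embeddings.cls_token, 2),            # shape (1, 1, hidden): prune along dim 2 (assumed path)
    (model.vit.embeddings.position_embeddings, 2),  # shape (1, seq, hidden): prune along dim 2 (assumed path)
]

pruner = tp.pruner.MagnitudePruner(
    model,
    example_inputs,
    importance=imp,
    ch_sparsity=0.5,
    channel_groups=channel_groups,
    unwrapped_parameters=unwrapped_parameters,  # (parameter, pruning_dim) pairs
    output_transform=lambda out: out.logits.sum(),
    ignored_layers=[model.classifier],
)
```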
