diff --git a/examples/hf_transformers/prune_vit.py b/examples/hf_transformers/prune_vit.py
new file mode 100644
index 00000000..3852d58c
--- /dev/null
+++ b/examples/hf_transformers/prune_vit.py
@@ -0,0 +1,57 @@
+from transformers import ViTImageProcessor, ViTForImageClassification
+from transformers.models.vit.modeling_vit import ViTSelfAttention
+import torch_pruning as tp
+from PIL import Image
+import requests
+
+url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
+image = Image.open(requests.get(url, stream=True).raw)
+
+processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224')
+model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224')
+example_inputs = processor(images=image, return_tensors="pt")["pixel_values"]
+#outputs = model(example_inputs)
+#logits = outputs.logits
+# model predicts one of the 1000 ImageNet classes
+#predicted_class_idx = logits.argmax(-1).item()
+#print("Predicted class:", model.config.id2label[predicted_class_idx])
+
+print(model)
+imp = tp.importance.MagnitudeImportance(p=2, group_reduction="mean")
+base_macs, base_params = tp.utils.count_ops_and_params(model, example_inputs)
+channel_groups = {}
+
+# All heads should be pruned simultaneously, so we group channels by head.
+for m in model.modules():
+    if isinstance(m, ViTSelfAttention):
+        channel_groups[m.query] = m.num_attention_heads
+        channel_groups[m.key] = m.num_attention_heads
+        channel_groups[m.value] = m.num_attention_heads
+
+pruner = tp.pruner.MagnitudePruner(
+    model,
+    example_inputs,
+    global_pruning=False, # If False, a uniform sparsity will be assigned to different layers.
+    importance=imp, # importance criterion for parameter selection
+    iterative_steps=1, # the number of iterations to achieve the target sparsity
+    ch_sparsity=0.5,
+    channel_groups=channel_groups,
+    output_transform=lambda out: out.logits.sum(),
+    ignored_layers=[model.classifier],
+    )
+
+for g in pruner.step(interactive=True):
+    #print(g)
+    g.prune()
+
+# Modify the attention head size and all head size after pruning
+for m in model.modules():
+    if isinstance(m, ViTSelfAttention):
+        m.attention_head_size = m.query.out_features // m.num_attention_heads
+        m.all_head_size = m.query.out_features
+
+print(model)
+test_output = model(example_inputs)
+pruned_macs, pruned_params = tp.utils.count_ops_and_params(model, example_inputs)
+print("Base MACs: %.2f G, Pruned MACs: %.2f G"%(base_macs/1e9, pruned_macs/1e9))
+print("Base Params: %.2f M, Pruned Params: %.2f M"%(base_params/1e6, pruned_params/1e6))
\ No newline at end of file
diff --git a/examples/hf_transformers/readme.md b/examples/hf_transformers/readme.md
new file mode 100644
index 00000000..e13b33d9
--- /dev/null
+++ b/examples/hf_transformers/readme.md
@@ -0,0 +1,6 @@
+# Example for HuggingFace ViT
+
+## Pruning
+```bash
+python prune_vit.py
+```
\ No newline at end of file
diff --git a/torch_pruning/dependency.py b/torch_pruning/dependency.py
index b6c9071e..dfed8892 100644
--- a/torch_pruning/dependency.py
+++ b/torch_pruning/dependency.py
@@ -634,7 +634,7 @@ def _detect_unwrapped_parameters(self, unwrapped_parameters):
             unwrapped_parameters = []
         unwrapped_detected = list( set(unwrapped_detected) - set([p for (p, _) in unwrapped_parameters]) )
         if len(unwrapped_detected)>0 and self.verbose:
-            warning_str = "Unwrapped parameters detected: {}.\n Torch-Pruning will prune the last non-singleton dimension of a parameter. If you wish to customize this behavior, please provide an unwrapped_parameters argument.".format([_param_to_name[p] for p in unwrapped_detected])
+            warning_str = "Unwrapped parameters detected: {}.\n Torch-Pruning will prune the last non-singleton dimension of these parameters. If you wish to change this behavior, please provide an unwrapped_parameters argument.".format([_param_to_name[p] for p in unwrapped_detected])
             warnings.warn(warning_str)
 
         # set default pruning dim for unwrapped parameters
@@ -782,9 +782,6 @@ def create_node_if_not_exists(grad_fn):
                 module = ops._SplitOp(self._op_id)
                 self._op_id+=1
             elif "view" in grad_fn.name().lower() or 'reshape' in grad_fn.name().lower():
-                #if 'reshape' in grad_fn.name().lower():
-                    #print(grad_fn.__dir__())
-                    #print(grad_fn._saved_self_sizes)
                 module = ops._ReshapeOp(self._op_id)
                 self._op_id+=1
             else:
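
Note: the prediction snippet that is commented out near the top of `prune_vit.py` can be reused after pruning as a quick sanity check. A minimal sketch, reusing `model` and `example_inputs` from the script above; the save path is hypothetical, and accuracy will usually degrade at 50% channel sparsity until the model is fine-tuned:

```python
import torch

# Re-run the ImageNet prediction from the commented-out snippet at the top of
# prune_vit.py, now on the pruned model.
logits = model(example_inputs).logits
predicted_class_idx = logits.argmax(-1).item()
print("Predicted class after pruning:", model.config.id2label[predicted_class_idx])

# Pruning changes parameter shapes, so loading a state_dict into a freshly
# constructed ViT will not match; saving the whole module object is one option.
torch.save(model, "pruned_vit.pth")  # hypothetical output path
```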