diff --git a/README.md b/README.md index de5f2718..d7fa4dda 100644 --- a/README.md +++ b/README.md @@ -54,7 +54,7 @@ For more technical details, please refer to our CVPR'23 paper: Please do not hesitate to open an [issue](https://github.com/VainF/Torch-Pruning/issues) if you encounter any problems with the library or the paper. Or Join our Discord or WeChat group for a chat: * Discord: [link](https://discord.gg/Pvd6hbYXRs) - * WeChat Group (Group size exceeded 400): [QR Code](https://github.com/VainF/Torch-Pruning/assets/18592211/35d66130-eb03-4dcb-ad75-8df784460ad3) + * WeChat Group [Group 1 (500/500, FULL)](https://github.com/VainF/Torch-Pruning/assets/18592211/35d66130-eb03-4dcb-ad75-8df784460ad3), [Group-2](https://github.com/VainF/Torch-Pruning/assets/18592211/4e5f98e9-86b6-46bd-9e9f-3275c5ccc2f4) ## Table of Contents - [Installation](#installation) @@ -424,7 +424,7 @@ Please refer to [benchmarks](benchmarks) for more details. > **DeepCache: Accelerating Diffusion Models for Free** [[Project]](https://github.com/horseee/DeepCache) [[Arxiv]](https://arxiv.org/abs/2312.00858) > *Xinyin Ma, Gongfan Fang, and Xinchao Wang* -> Preprint 2023 +> CVPR 2024 > **0.1% Data Makes Segment Anything Slim** [[Project]](https://github.com/czg1225/SlimSAM) [[Arxiv]](https://arxiv.org/abs/2312.05284) > *Zigeng Chen, Gongfan Fang, Xinyin Ma, Xinchao Wang* diff --git a/examples/sparse_training/benchmark_importance_criteria.py b/examples/sparse_training/benchmark_importance_criteria.py index c26b9e8d..c20cacfd 100644 --- a/examples/sparse_training/benchmark_importance_criteria.py +++ b/examples/sparse_training/benchmark_importance_criteria.py @@ -50,8 +50,8 @@ def validate_model(model, val_loader): # Importance criteria imp_dict = { - 'Group Hessian': tp.importance.OBDImportance(group_reduction='mean'), - 'Single-layer Hessian': tp.importance.OBDImportance(group_reduction='first'), + 'Group OBD': tp.importance.OBDImportance(group_reduction='mean'), + 'Single-layer OBD': tp.importance.OBDImportance(group_reduction='first'), 'Group Taylor': tp.importance.TaylorImportance(group_reduction='mean'), 'Single-layer Taylor': tp.importance.TaylorImportance(group_reduction='first'), diff --git a/examples/sparse_training/main.py b/examples/sparse_training/main.py index e448baf6..092728b8 100644 --- a/examples/sparse_training/main.py +++ b/examples/sparse_training/main.py @@ -18,6 +18,7 @@ parser.add_argument("--model", type=str, required=True) parser.add_argument("--verbose", action="store_true", default=False) parser.add_argument("--dataset", type=str, default="cifar100", choices=['cifar10', 'cifar100', 'modelnet40']) +parser.add_argument('--dataroot', default='data', help='path to your datasets') parser.add_argument("--batch-size", type=int, default=128) parser.add_argument("--total-epochs", type=int, default=100) parser.add_argument("--lr-decay-milestones", default="60,80", type=str, help="milestones for learning rate decay") @@ -25,6 +26,7 @@ parser.add_argument("--lr", default=0.01, type=float, help="learning rate") parser.add_argument("--restore", type=str, default=None) parser.add_argument('--output-dir', default='run', help='path where to save') +parser.add_argument("--finetune", action="store_true", default=False, help='whether finetune or not') # For pruning parser.add_argument("--method", type=str, default=None) @@ -46,17 +48,34 @@ args = parser.parse_args() -def progressive_pruning(pruner, model, speed_up, example_inputs): +def progressive_pruning(pruner, model, speed_up, example_inputs, 
train_loader=None): model.eval() base_ops, _ = tp.utils.count_ops_and_params(model, example_inputs=example_inputs) current_speed_up = 1 while current_speed_up < speed_up: - pruner.step(interactive=False) + if args.method == "obdc": + model.zero_grad() + imp=pruner.importance + imp._prepare_model(model, pruner) + for k, (imgs, lbls) in enumerate(train_loader): + if k>=10: break + imgs = imgs.to(args.device) + lbls = lbls.to(args.device) + output = model(imgs) + sampled_y = torch.multinomial(torch.nn.functional.softmax(output.cpu().data, dim=1), + 1).squeeze().to(args.device) + loss_sample = F.cross_entropy(output, sampled_y) + loss_sample.backward() + imp.step() + pruner.step() + imp._rm_hooks(model) + imp._clear_buffer() + else: + pruner.step() pruned_ops, _ = tp.utils.count_ops_and_params(model, example_inputs=example_inputs) current_speed_up = float(base_ops) / pruned_ops if pruner.current_step == pruner.iterative_steps: break - #print(current_speed_up) return current_speed_up def eval(model, test_loader, device=None): @@ -169,9 +188,18 @@ def get_pruner(model, example_inputs): elif args.method == "l1": imp = tp.importance.MagnitudeImportance(p=1) pruner_entry = partial(tp.pruner.MagnitudePruner, global_pruning=args.global_pruning) + elif args.method == "l2": + imp = tp.importance.MagnitudeImportance(p=2) + pruner_entry = partial(tp.pruner.MagnitudePruner, global_pruning=args.global_pruning) + elif args.method == "fpgm": + imp = tp.importance.FPGMImportance(p=2) + pruner_entry = partial(tp.pruner.MagnitudePruner, global_pruning=args.global_pruning) + elif args.method == "obdc": + imp = tp.importance.OBDCImportance(group_reduction='mean', num_classes=args.num_classes) + pruner_entry = partial(tp.pruner.MagnitudePruner, global_pruning=args.global_pruning) elif args.method == "lamp": imp = tp.importance.LAMPImportance(p=2) - pruner_entry = partial(tp.pruner.BNScalePruner, global_pruning=args.global_pruning) + pruner_entry = partial(tp.pruner.MagnitudePruner, global_pruning=args.global_pruning) elif args.method == "slim": args.sparsity_learning = True imp = tp.importance.BNScaleImportance() @@ -242,7 +270,7 @@ def main(): # Model & Dataset args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") num_classes, train_dst, val_dst, input_size = registry.get_dataset( - args.dataset, data_root="data" + args.dataset, data_root=args.dataroot ) args.num_classes = num_classes model = registry.get_model(args.model, num_classes=num_classes, pretrained=True, target_dataset=args.dataset) @@ -315,7 +343,7 @@ def main(): ori_ops, ori_size = tp.utils.count_ops_and_params(model, example_inputs=example_inputs) ori_acc, ori_val_loss = eval(model, test_loader, device=args.device) args.logger.info("Pruning...") - progressive_pruning(pruner, model, speed_up=args.speed_up, example_inputs=example_inputs) + progressive_pruning(pruner, model, speed_up=args.speed_up, example_inputs=example_inputs, train_loader=train_loader) del pruner # remove reference args.logger.info(model) pruned_ops, pruned_size = tp.utils.count_ops_and_params(model, example_inputs=example_inputs) @@ -340,17 +368,18 @@ def main(): ) # 2. 
Finetuning - args.logger.info("Finetuning...") - train_model( - model, - epochs=args.total_epochs, - lr=args.lr, - lr_decay_milestones=args.lr_decay_milestones, - train_loader=train_loader, - test_loader=test_loader, - device=args.device, - save_state_dict_only=False, - ) + if args.finetune: + args.logger.info("Finetuning...") + train_model( + model, + epochs=args.total_epochs, + lr=args.lr, + lr_decay_milestones=args.lr_decay_milestones, + train_loader=train_loader, + test_loader=test_loader, + device=args.device, + save_state_dict_only=False, + ) elif args.mode == "test": model.eval() ops, params = tp.utils.count_ops_and_params( diff --git a/examples/sparse_training/tools/draw.py b/examples/sparse_training/tools/draw.py index 317705c1..aec999d8 100644 --- a/examples/sparse_training/tools/draw.py +++ b/examples/sparse_training/tools/draw.py @@ -7,8 +7,8 @@ plt.style.use('bmh') color_dict = { - 'Group Hessian': "C0", - 'Single-layer Hessian': "C0", + 'Group OBD': "C0", + 'Single-layer OBD': "C0", 'Random': "C1", diff --git a/examples/transformers/draw_acc_curve.py b/examples/transformers/draw_acc_curve.py index f5ec6dff..67c9cf86 100644 --- a/examples/transformers/draw_acc_curve.py +++ b/examples/transformers/draw_acc_curve.py @@ -13,7 +13,7 @@ def parse_acc_from_file(file_path): return acc log_dict = { - 'Hessian-uniform': 'output/vit_b_16_pruning_hessian_uniform/train.log', + 'OBD-uniform': 'output/vit_b_16_pruning_OBD_uniform/train.log', 'Taylor-uniform': 'output/vit_b_16_pruning_taylor_uniform/train.log', 'Taylor-bottleneck': 'output/vit_b_16_pruning_taylor_bottleneck/train.log', 'L1-uniform': 'output/vit_b_16_pruning_l1_uniform/train.log', diff --git a/examples/transformers/prune_timm_vit.py b/examples/transformers/prune_timm_vit.py index 364a4ce1..6b062c90 100644 --- a/examples/transformers/prune_timm_vit.py +++ b/examples/transformers/prune_timm_vit.py @@ -20,7 +20,7 @@ def parse_args(): parser.add_argument('--taylor_batchs', default=10, type=int, help='number of batchs for taylor criterion') parser.add_argument('--pruning_ratio', default=0.5, type=float, help='prune ratio') parser.add_argument('--bottleneck', default=False, action='store_true', help='bottleneck or uniform') - parser.add_argument('--pruning_type', default='l1', type=str, help='pruning type', choices=['random', 'taylor', 'l2', 'l1', 'hessian']) + parser.add_argument('--pruning_type', default='l1', type=str, help='pruning type', choices=['random', 'taylor', 'l2', 'l1', 'OBD']) parser.add_argument('--test_accuracy', default=False, action='store_true', help='test accuracy') parser.add_argument('--global_pruning', default=False, action='store_true', help='global pruning') parser.add_argument('--prune_num_heads', default=False, action='store_true', help='global pruning') @@ -111,11 +111,11 @@ def main(): imp = tp.importance.GroupNormImportance(p=2) elif args.pruning_type == 'l1': imp = tp.importance.GroupNormImportance(p=1) - elif args.pruning_type == 'hessian': + elif args.pruning_type == 'OBD': imp = tp.importance.GroupOBDImportance() else: raise NotImplementedError - if args.pruning_type in ['taylor', 'hessian'] or args.test_accuracy: + if args.pruning_type in ['taylor', 'OBD'] or args.test_accuracy: train_loader, val_loader = prepare_imagenet(args.data_path, train_batch_size=args.train_batch_size, val_batch_size=args.val_batch_size, use_imagenet_mean_std=args.use_imagenet_mean_std) # Load the model diff --git a/examples/transformers/readme.md b/examples/transformers/readme.md index d5e517df..6e098520 100644 --- 
a/examples/transformers/readme.md +++ b/examples/transformers/readme.md @@ -43,13 +43,13 @@ bash scripts/finetune_timm_vit_b_16_taylor_uniform.sh ``` Pruning results for ImageNet-21K-ft-1K (Timm): -| | ViT-B/16 (Timm) | ViT_B/32 (Timm) | Group L2 (Uniform) | Group Taylor (Uniform) | Group Taylor (Bottleneck) | Group Hessian (Uniform) | +| | ViT-B/16 (Timm) | ViT_B/32 (Timm) | Group L2 (Uniform) | Group Taylor (Uniform) | Group Taylor (Bottleneck) | Group OBD (Uniform) | | :-- | :--: | :--: | :--: | :--: | :--: | :--: | | **#Params** | 86.57 M | 88.22 M | 22.05 M | 22.05 M | 24.83 M | 22.05 M | | **MACs** | 17.59 G | 4.41 G | 4.61 G | 4.61 G | 4.62 G | 4.61 G | | **Acc @ Epoch 300** | 85.21 | 80.68 | 78.11 | 80.19 | 80.06 | 80.15 | | **Latency (Bs=1, A5000)** | 5.21 ms
+- 0.05 ms | 3.87 ms +- 0.05 ms | 3.99 ms +- 0.10 ms | 3.99 ms +- 0.10 ms | 3.87 ms +- 0.14 ms | 3.99 ms
+- 0.10 ms | -| **Checkpoints** | - | - | [ckpt](https://github.com/VainF/Torch-Pruning/releases/download/v1.2.5/vit_b_16_pruning_l2_uniform.pth) | [ckpt](https://github.com/VainF/Torch-Pruning/releases/download/v1.2.5/vit_b_16_pruning_taylor_uniform.pth) | [ckpt](https://github.com/VainF/Torch-Pruning/releases/download/v1.2.5/vit_b_16_pruning_taylor_bottleneck.pth) | [ckpt](https://github.com/VainF/Torch-Pruning/releases/download/v1.2.5/vit_b_16_pruning_hessian_uniform.pth) | +| **Checkpoints** | - | - | [ckpt](https://github.com/VainF/Torch-Pruning/releases/download/v1.2.5/vit_b_16_pruning_l2_uniform.pth) | [ckpt](https://github.com/VainF/Torch-Pruning/releases/download/v1.2.5/vit_b_16_pruning_taylor_uniform.pth) | [ckpt](https://github.com/VainF/Torch-Pruning/releases/download/v1.2.5/vit_b_16_pruning_taylor_bottleneck.pth) | [ckpt](https://github.com/VainF/Torch-Pruning/releases/download/v1.2.5/vit_b_16_pruning_OBD_uniform.pth) | *Notes:* * Uniform - We apply the same pruning ratio to all layers. diff --git a/examples/transformers/scripts/finetune_timm_vit_b_16_hessian_uniform.sh b/examples/transformers/scripts/finetune_timm_vit_b_16_hessian_uniform.sh index 50012b2c..e210f1cd 100644 --- a/examples/transformers/scripts/finetune_timm_vit_b_16_hessian_uniform.sh +++ b/examples/transformers/scripts/finetune_timm_vit_b_16_hessian_uniform.sh @@ -1,5 +1,5 @@ torchrun --nproc_per_node=8 finetune.py \ - --model "output/pruned/vit_base_patch16_224_pruned_hessian_uniform.pth" \ + --model "output/pruned/vit_base_patch16_224_pruned_OBD_uniform.pth" \ --epochs 300 \ --batch-size 256 \ --opt adamw \ @@ -17,4 +17,4 @@ torchrun --nproc_per_node=8 finetune.py \ --ra-sampler \ --cutmix-alpha 1.0 \ --data-path "data/imagenet" \ - --output-dir output/vit_b_16_pruning_hessian_uniform \ No newline at end of file + --output-dir output/vit_b_16_pruning_OBD_uniform \ No newline at end of file diff --git a/examples/transformers/scripts/prune_timm_vit_b_16_hessian_uniform.sh b/examples/transformers/scripts/prune_timm_vit_b_16_hessian_uniform.sh index 39e35671..bf3c7e55 100644 --- a/examples/transformers/scripts/prune_timm_vit_b_16_hessian_uniform.sh +++ b/examples/transformers/scripts/prune_timm_vit_b_16_hessian_uniform.sh @@ -1,10 +1,10 @@ python prune_timm_vit.py \ --model_name vit_base_patch16_224 \ - --pruning_type hessian \ + --pruning_type OBD \ --pruning_ratio 0.5 \ --taylor_batchs 10 \ --test_accuracy \ --data_path data/imagenet \ --train_batch_size 64 \ --val_batch_size 64 \ - --save_as output/pruned/vit_base_patch16_224_pruned_hessian_uniform.pth \ \ No newline at end of file + --save_as output/pruned/vit_base_patch16_224_pruned_OBD_uniform.pth \ \ No newline at end of file diff --git a/tests/test_hessian_importance.py b/tests/test_hessian_importance.py index 2037ba83..b25887f8 100644 --- a/tests/test_hessian_importance.py +++ b/tests/test_hessian_importance.py @@ -2,7 +2,7 @@ from torchvision.models import resnet18 import torch_pruning as tp -def test_hessian(): +def test_OBD(): model = resnet18(pretrained=True) # Importance criteria @@ -52,4 +52,4 @@ def test_hessian(): # ... 
if __name__=="__main__": - test_hessian() \ No newline at end of file + test_OBD() \ No newline at end of file diff --git a/tests/test_regularization.py b/tests/test_regularization.py index d7703666..b9f48788 100644 --- a/tests/test_regularization.py +++ b/tests/test_regularization.py @@ -18,8 +18,15 @@ def test_pruner(): [tp.importance.GroupNormImportance, tp.pruner.GroupNormPruner], [tp.importance.BNScaleImportance, tp.pruner.BNScalePruner], [tp.importance.GroupNormImportance, tp.pruner.GrowingRegPruner], + [tp.importance.MagnitudeImportance, tp.pruner.GroupNormPruner], + [tp.importance.LAMPImportance, tp.pruner.GroupNormPruner], + [tp.importance.OBDCImportance, tp.pruner.GroupNormPruner], + [tp.importance.FPGMImportance, tp.pruner.GroupNormPruner], ]: - imp = imp_cls() + if imp_cls == tp.importance.OBDCImportance: + imp = imp_cls(num_classes=1000) + else: + imp = imp_cls() ignored_layer_outputs = [] # DO NOT prune the final classifier! for m in model.modules(): @@ -37,7 +44,12 @@ def test_pruner(): ) for i in range(iterative_steps): - model(example_inputs).sum().backward() + if isinstance(imp, tp.importance.OBDCImportance): + imp._prepare_model(model, pruner) + model(example_inputs).sum().backward() + imp.step() + else: + model(example_inputs).sum().backward() grad_dict = {} for p in model.parameters(): if p.grad is not None: @@ -53,6 +65,9 @@ def test_pruner(): print(name, "has no grad") for g in pruner.step(interactive=True): g.prune() + if isinstance(imp, tp.importance.OBDCImportance): + imp._rm_hooks(model) + imp._clear_buffer() if __name__ == "__main__": diff --git a/torch_pruning/dependency.py b/torch_pruning/dependency.py index 2015ebad..9ea2335e 100644 --- a/torch_pruning/dependency.py +++ b/torch_pruning/dependency.py @@ -176,7 +176,7 @@ def prune(self, idxs=None, record_history=True): self._DG._param_to_name[pruned_parameter] = name self._DG.module2node[pruned_parameter] = self._DG.module2node.pop(old_parameter) self._DG.module2node[pruned_parameter].module = pruned_parameter - else: # prune nn.Module + else: dep(idxs) if record_history: diff --git a/torch_pruning/pruner/algorithms/metapruner.py b/torch_pruning/pruner/algorithms/metapruner.py index 4cf3143d..d36ecc3a 100644 --- a/torch_pruning/pruner/algorithms/metapruner.py +++ b/torch_pruning/pruner/algorithms/metapruner.py @@ -2,6 +2,8 @@ import torch.nn as nn import typing, warnings +from torch_pruning.pruner.importance import OBDCImportance + from .scheduler import linear_scheduler from ..import function from ... 
import ops, dependency @@ -239,6 +241,8 @@ def step(self, interactive=False)-> typing.Union[typing.Generator, None]: else: for group in pruning_method(): group.prune() + # print("gg") + # exit(0) def estimate_importance(self, group) -> torch.Tensor: return self.importance(group) @@ -413,6 +417,9 @@ def prune_local(self) -> typing.Generator: if len(pruning_idxs)==0: continue pruning_idxs = torch.unique( torch.cat(pruning_idxs, 0) ).tolist() + if isinstance(self.importance, OBDCImportance): + self.importance.adjust_fisher(group, pruning_idxs) + group = self.DG.get_pruning_group( module, pruning_fn, pruning_idxs) @@ -542,6 +549,8 @@ def prune_global(self) -> typing.Generator: if len(pruning_indices)==0: continue pruning_indices = torch.unique(torch.cat(pruning_indices, 0)).tolist() + if isinstance(self.importance, OBDCImportance): + self.importance.adjust_fisher(group, pruning_indices) # create pruning group group = self.DG.get_pruning_group( module, pruning_fn, pruning_indices) diff --git a/torch_pruning/pruner/importance.py b/torch_pruning/pruner/importance.py index 2465e9c4..c7f8c5d4 100644 --- a/torch_pruning/pruner/importance.py +++ b/torch_pruning/pruner/importance.py @@ -5,6 +5,12 @@ import typing from . import function from ..dependency import Group +from .._helpers import _FlattenIndexMapping +from .. import ops +import math +import numpy as np +from collections import OrderedDict +from ..utils.compute_mat_grad import ComputeMatGrad __all__ = [ # Base Class @@ -31,7 +37,6 @@ class Importance(abc.ABC): It should accept a group as inputs, and return a 1-D tensor with the same length as the number of channels. All groups must be pruned simultaneously and thus their importance should be accumulated across channel groups. - Just ignore the ch_groups if you are not familar with grouping. Example: ```python @@ -90,13 +95,23 @@ def __init__(self, self.target_types = target_types self.bias = bias - def _lamp(self, imp): # Layer-adaptive Sparsity for the Magnitude-based Pruning - argsort_idx = torch.argsort(imp, dim=0, descending=True) - sorted_imp = imp[argsort_idx.tolist()] - cumsum_imp = torch.cumsum(sorted_imp, dim=0) - sorted_imp = sorted_imp / cumsum_imp - inversed_idx = torch.argsort(argsort_idx).tolist() # [0, 1, 2, 3, ..., ] - return sorted_imp[inversed_idx] + def _lamp(self, scores): # Layer-adaptive Sparsity for the Magnitude-based Pruning + """ + Normalizing scheme for LAMP. 
+ """ + # sort scores in an ascending order + sorted_scores,sorted_idx = scores.view(-1).sort(descending=False) + # compute cumulative sum + scores_cumsum_temp = sorted_scores.cumsum(dim=0) + scores_cumsum = torch.zeros(scores_cumsum_temp.shape,device=scores.device) + scores_cumsum[1:] = scores_cumsum_temp[:len(scores_cumsum_temp)-1] + # normalize by cumulative sum + sorted_scores /= (scores.sum() - scores_cumsum) + # tidy up and output + new_scores = torch.zeros(scores_cumsum.shape,device=scores.device) + new_scores[sorted_idx] = sorted_scores + + return new_scores.view(scores.shape) def _normalize(self, group_importance, normalizer): if normalizer is None: @@ -295,6 +310,76 @@ def __init__(self, p=2, group_reduction="mean", normalizer='lamp', bias=False): assert normalizer == 'lamp' super().__init__(p=p, group_reduction=group_reduction, normalizer=normalizer, bias=bias) + +class FPGMImportance(GroupNormImportance): + """Filter Pruning via Geometric Median for Deep Convolutional Neural Networks Acceleration, + http://openaccess.thecvf.com/content_CVPR_2019/papers/He_Filter_Pruning_via_Geometric_Median_for_Deep_Convolutional_Neural_Networks_CVPR_2019_paper.pdf + """ + + def __init__(self, p=2, group_reduction="mean", normalizer='mean', bias=False): + super().__init__(p=p, group_reduction=group_reduction, normalizer=normalizer, bias=bias) + + @torch.no_grad() + def __call__(self, group, **kwargs): + group_imp = [] + group_idxs = [] + # Iterate over all groups and estimate group importance + for i, (dep, idxs) in enumerate(group): + layer = dep.layer + prune_fn = dep.pruning_fn + root_idxs = group[i].root_idxs + if not isinstance(layer, tuple(self.target_types)): + continue + #################### + # Conv/Linear Output + #################### + if prune_fn in [ + function.prune_conv_out_channels, + function.prune_linear_out_channels, + ]: + if hasattr(layer, "transposed") and layer.transposed: + w = layer.weight.data.transpose(1, 0)[idxs].flatten(1) + else: + w = layer.weight.data[idxs].flatten(1) + local_imp = w.abs().pow(self.p) + # calculate the euclidean distance as similarity + similar_matrix = torch.cdist(local_imp.unsqueeze(0), local_imp.unsqueeze(0), p=2).squeeze(0) + similar_sum = torch.sum(torch.abs(similar_matrix), dim=0) + group_imp.append(similar_sum) + group_idxs.append(root_idxs) + + #################### + # Conv/Linear Input + #################### + elif prune_fn in [ + function.prune_conv_in_channels, + function.prune_linear_in_channels, + ]: + if hasattr(layer, "transposed") and layer.transposed: + w = (layer.weight.data).flatten(1) + else: + w = (layer.weight.data).transpose(0, 1).flatten(1) + + local_imp = w.abs().pow(self.p) + + # repeat importance for group convolutions + if prune_fn == function.prune_conv_in_channels and layer.groups != layer.in_channels and layer.groups != 1: + local_imp = local_imp.repeat(layer.groups) + local_imp = local_imp[idxs] + similar_matrix = torch.cdist(local_imp.unsqueeze(0), local_imp.unsqueeze(0), p=2).squeeze(0) + similar_sum = torch.sum(torch.abs(similar_matrix), dim=0) + group_imp.append(similar_sum) + group_idxs.append(root_idxs) + + # FPGMImportance should not care about BatchNorm and LayerNorm + + if len(group_imp) == 0: # skip groups without parameterized layers + return None + + group_imp = self._reduce(group_imp, group_idxs) + group_imp = self._normalize(group_imp, self.normalizer) + return group_imp + class RandomImportance(Importance): """ Random importance estimator Example: @@ -450,6 +535,125 @@ def __call__(self, group): 
group_imp = self._normalize(group_imp, self.normalizer) return group_imp +class OBDCImportance(GroupNormImportance): + """EigenDamage: Structured Pruning in the Kronecker-Factored Eigenbasis: + http://proceedings.mlr.press/v97/wang19g/wang19g.pdf + """ + def __init__(self, + group_reduction:str="mean", + normalizer:str='mean', + bias=False, + target_types:list=[nn.modules.conv._ConvNd, nn.Linear], + num_classes=100): + self.group_reduction = group_reduction + self.normalizer = normalizer + self.target_types = target_types + self.bias = bias + self.A, self.DS = {}, {} + self.Fisher = {} + self.MatGradHandler = ComputeMatGrad() + self.steps = 0 + self.eps = 1e-10 + self.modules = [] + self.num_classes = num_classes + self.known_modules = {'Linear', 'Conv2d'} + + def step(self): + with torch.no_grad(): + for m in self.modules: + A, DS = self.A[m], self.DS[m] + grad_mat = self.MatGradHandler(A, DS, m) + grad_mat *= DS.size(0) + if self.steps == 0: + self.Fisher[m] = grad_mat.new(grad_mat.size()[1:]).fill_(0) + self.Fisher[m] += (grad_mat.pow_(2)).sum(0) + self.A[m] = None + self.DS[m] = None + self.steps += 1 + + def adjust_fisher(self, group, idxs): + for i, (dep, id) in enumerate(group): + layer = dep.target.module + if layer in self.modules: + if layer.weight.grad is not None: + shape = layer.weight.shape + if isinstance(layer, nn.modules.conv._ConvNd): + kernel_size = shape[2]*shape[3] + else: + kernel_size = 1 + indices_to_keep = list(range(self.Fisher[layer].shape[1])) + for idx in idxs: + indices_to_keep = [i for i in indices_to_keep if not (idx*kernel_size <= i < (idx+1)*kernel_size)] + self.Fisher[layer] = torch.index_select(self.Fisher[layer], 1, torch.LongTensor(indices_to_keep).to(self.Fisher[layer].device)) + + + def _rm_hooks(self, model): + for m in self.modules: + for h in self._hooks: + h.remove() + + def _save_input(self, module, input): + self.A[module] = input[0].data + + def _save_grad_output(self, module, grad_input, grad_output): + self.DS[module] = grad_output[0].data + + def _prepare_model(self, model, pruner): + self._hooks = [] + for group in pruner.DG.get_all_groups(ignored_layer_inputs=pruner.ignored_layer_inputs, ignored_layer_outputs=pruner.ignored_layer_outputs, target_layers=pruner.target_layers, target_layer_types=pruner.target_layer_types): + group = pruner._downstream_node_as_root_if_attention(group) + for i, (dep, idxs) in enumerate(group): + layer = dep.target.module + if isinstance(layer, tuple(self.target_types)) and dep.handler in [ + function.prune_conv_out_channels, + function.prune_linear_out_channels, + ]: + self.modules.append(layer) + self._hooks.append(layer.register_forward_pre_hook(self._save_input)) + self._hooks.append(layer.register_backward_hook(self._save_grad_output)) + + def _clear_buffer(self): + self.Fisher = {} + self.modules = [] + self.steps = 0 + + @torch.no_grad() + def __call__(self, group): + group_imp = [] + group_idxs = [] + for i, (dep, idxs) in enumerate(group): + idxs.sort() + layer = dep.target.module + prune_fn = dep.handler + root_idxs = group[i].root_idxs + if not isinstance(layer, tuple(self.target_types)) or (isinstance(layer, torch.nn.Linear) and layer.out_features == self.num_classes): + continue + F_diag = (self.Fisher[layer] / self.steps + self.eps) + if prune_fn in [ + function.prune_conv_out_channels, + function.prune_linear_out_channels, + ]: + if layer.weight.grad is not None: + if hasattr(layer, "transposed") and layer.transposed: + w = layer.weight.data.transpose(1, 0)[idxs].flatten(1) + else: + w = 
layer.weight.data[idxs].flatten(1) + local_imp = (w ** 2 * F_diag).sum(1) + group_imp.append(local_imp) + group_idxs.append(root_idxs) + + if self.bias and layer.bias is not None and layer.bias.grad is not None: + b = layer.bias.data[idxs] + local_imp = (b ** 2 * F_diag).sum(1) + group_imp.append(local_imp) + group_idxs.append(root_idxs) + + if len(group_imp) == 0: # skip groups without parameterized layers + return None + group_imp = self._reduce(group_imp, group_idxs) + group_imp = self._normalize(group_imp, self.normalizer) + return group_imp + class GroupOBDImportance(GroupNormImportance): """Grouped Optimal Brain Damage: https://proceedings.neurips.cc/paper/1989/hash/6c9882bbac1c7093bd25041881277658-Abstract.html @@ -615,4 +819,5 @@ class TaylorImportance(GroupTaylorImportance): pass class OBDImportance(GroupOBDImportance): - pass \ No newline at end of file + pass + diff --git a/torch_pruning/utils/compute_mat_grad.py b/torch_pruning/utils/compute_mat_grad.py new file mode 100644 index 00000000..e89051ed --- /dev/null +++ b/torch_pruning/utils/compute_mat_grad.py @@ -0,0 +1,79 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +def _extract_patches(x, kernel_size, stride, padding): + """ + :param x: The input feature maps. (batch_size, in_c, h, w) + :param kernel_size: the kernel size of the conv filter (tuple of two elements) + :param stride: the stride of conv operation (tuple of two elements) + :param padding: number of paddings. be a tuple of two elements + :return: (batch_size, out_h, out_w, in_c*kh*kw) + """ + if padding[0] + padding[1] > 0: + x = F.pad(x, (padding[1], padding[1], padding[0], + padding[0])).data # Actually check dims + x = x.unfold(2, kernel_size[0], stride[0]) + x = x.unfold(3, kernel_size[1], stride[1]) + x = x.transpose_(1, 2).transpose_(2, 3).contiguous() + x = x.view( + x.size(0), x.size(1), x.size(2), + x.size(3) * x.size(4) * x.size(5)) + return x + + + +def try_contiguous(x): + if not x.is_contiguous(): + x = x.contiguous() + + return x + + +class ComputeMatGrad: + + @classmethod + def __call__(cls, input, grad_output, layer): + if isinstance(layer, nn.Linear): + grad = cls.linear(input, grad_output, layer) + elif isinstance(layer, nn.Conv2d): + grad = cls.conv2d(input, grad_output, layer) + else: + raise NotImplementedError + return grad + + @staticmethod + def linear(input, grad_output, layer): + """ + :param input: batch_size * input_dim + :param grad_output: batch_size * output_dim + :param layer: [nn.module] output_dim * input_dim + :return: batch_size * output_dim * (input_dim + [1 if with bias]) + """ + with torch.no_grad(): + if layer.bias is not None: + input = torch.cat([input, input.new(input.size(0), 1).fill_(1)], 1) + input = input.unsqueeze(1) + grad_output = grad_output.unsqueeze(2) + grad = torch.bmm(grad_output, input) + return grad + + @staticmethod + def conv2d(input, grad_output, layer): + """ + :param input: batch_size * in_c * in_h * in_w + :param grad_output: batch_size * out_c * h * w + :param layer: nn.module batch_size * out_c * (in_c*k_h*k_w + [1 if with bias]) + :return: + """ + with torch.no_grad(): + input = _extract_patches(input, layer.kernel_size, layer.stride, layer.padding) + input = input.view(-1, input.size(-1)) # b * hw * in_c*kh*kw + grad_output = grad_output.transpose(1, 2).transpose(2, 3) + grad_output = try_contiguous(grad_output).view(grad_output.size(0), -1, grad_output.size(-1)) + # b * hw * out_c + if layer.bias is not None: + input = torch.cat([input, input.new(input.size(0), 
1).fill_(1)], 1) + input = input.view(grad_output.size(0), -1, input.size(-1)) # b * hw * in_c*kh*kw + grad = torch.einsum('abm,abn->amn', (grad_output, input)) + return grad \ No newline at end of file
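
Usage sketch for the OBDC (EigenDamage, Kronecker-factored Fisher) criterion added in this patch. It mirrors the `obdc` branch of `progressive_pruning`/`get_pruner` in `examples/sparse_training/main.py` and the hook handling in `tests/test_regularization.py`; the ResNet-18 model, the random calibration batches, and the `pruning_ratio`/`ignored_layers` arguments are illustrative placeholders and may need adjusting to your Torch-Pruning version.

```python
# Minimal sketch, assuming torchvision ResNet-18 and random stand-in data.
import torch
import torch.nn.functional as F
import torch_pruning as tp
from torchvision.models import resnet18

model = resnet18(pretrained=True).eval()
example_inputs = torch.randn(1, 3, 224, 224)

imp = tp.importance.OBDCImportance(group_reduction='mean', num_classes=1000)
pruner = tp.pruner.MagnitudePruner(          # pairing used by main.py for --method obdc
    model, example_inputs, importance=imp,
    pruning_ratio=0.5,                        # placeholder ratio
    ignored_layers=[model.fc],                # never prune the classifier
)

imp._prepare_model(model, pruner)             # register forward/backward hooks
for _ in range(10):                           # a few calibration batches
    x = torch.randn(8, 3, 224, 224)           # replace with real images
    model.zero_grad()
    out = model(x)
    # sample labels from the model's own predictions, as main.py does
    y = torch.multinomial(F.softmax(out.detach(), dim=1), 1).squeeze()
    F.cross_entropy(out, y).backward()
    imp.step()                                # fold captured activations/grads into the Fisher
pruner.step()                                 # prune with the accumulated OBDC importance
imp._rm_hooks(model)
imp._clear_buffer()
```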
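
A minimal sketch of the new FPGM (geometric-median) criterion, following the `fpgm` branch of `get_pruner` in `examples/sparse_training/main.py`, which pairs `FPGMImportance(p=2)` with `MagnitudePruner`. The model and the `pruning_ratio` value are placeholders.

```python
# Minimal sketch, assuming a torchvision ResNet-18.
import torch
import torch_pruning as tp
from torchvision.models import resnet18

model = resnet18(pretrained=True)
example_inputs = torch.randn(1, 3, 224, 224)

imp = tp.importance.FPGMImportance(p=2)       # distance-to-other-filters importance
pruner = tp.pruner.MagnitudePruner(
    model, example_inputs, importance=imp,
    pruning_ratio=0.5,                        # placeholder ratio
    ignored_layers=[model.fc],
)
pruner.step()                                 # data-free: no gradients required
print(tp.utils.count_ops_and_params(model, example_inputs))
```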
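
The rewritten `_lamp` now follows the LAMP paper's normalization directly: scores are sorted in ascending order and each one is divided by the sum of all scores that are at least as large, so the largest entry in a layer is always mapped to 1. A tiny standalone numeric check of that behaviour (not part of the patch's tests):

```python
# LAMP normalization on a toy score vector [1, 2, 3]:
# expected output 1/6, 2/5, 3/3 = [0.1667, 0.4000, 1.0000].
import torch

scores = torch.tensor([1., 2., 3.])
sorted_scores, sorted_idx = scores.view(-1).sort(descending=False)
cumsum = torch.zeros_like(sorted_scores)
cumsum[1:] = sorted_scores.cumsum(dim=0)[:-1]       # sum of strictly smaller scores
lamp = torch.zeros_like(scores)
lamp[sorted_idx] = sorted_scores / (scores.sum() - cumsum)
print(lamp)   # tensor([0.1667, 0.4000, 1.0000])
```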
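
`ComputeMatGrad` produces per-sample parameter gradients (output-gradient ⊗ input, with a ones column appended when the layer has a bias); `OBDCImportance.step()` squares and sums these to build its Fisher estimate. Below is a standalone sanity check under assumed usage: the hook wiring imitates `_save_input`/`_save_grad_output`, but uses `register_full_backward_hook` instead of the deprecated `register_backward_hook` that the patch calls.

```python
# Check that the per-sample gradients summed over the batch recover .grad.
import torch
import torch.nn as nn
from torch_pruning.utils.compute_mat_grad import ComputeMatGrad

layer = nn.Linear(4, 3)
x = torch.randn(2, 4)

acts, grads = {}, {}
layer.register_forward_pre_hook(lambda m, inp: acts.update(a=inp[0].detach()))
layer.register_full_backward_hook(lambda m, gi, go: grads.update(g=go[0].detach()))

layer(x).sum().backward()
per_sample = ComputeMatGrad()(acts['a'], grads['g'], layer)   # shape (2, 3, 5): weight + bias column
print(per_sample.shape)
assert torch.allclose(per_sample.sum(0)[:, :4], layer.weight.grad, atol=1e-5)
assert torch.allclose(per_sample.sum(0)[:, 4], layer.bias.grad, atol=1e-5)
```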