Merge OBDC & FPGM
VainF committed Mar 9, 2024
2 parents 35c11f3 + 319656b commit 01893e3
Showing 15 changed files with 384 additions and 47 deletions.
4 changes: 2 additions & 2 deletions README.md
@@ -54,7 +54,7 @@ For more technical details, please refer to our CVPR'23 paper:
Please do not hesitate to open an [issue](https://github.com/VainF/Torch-Pruning/issues) if you encounter any problems with the library or the paper.
Or Join our Discord or WeChat group for a chat:
* Discord: [link](https://discord.gg/Pvd6hbYXRs)
* WeChat Group (Group size exceeded 400): [QR Code](https://github.com/VainF/Torch-Pruning/assets/18592211/35d66130-eb03-4dcb-ad75-8df784460ad3)
* WeChat Group [Group 1 (500/500, FULL)](https://github.com/VainF/Torch-Pruning/assets/18592211/35d66130-eb03-4dcb-ad75-8df784460ad3), [Group-2](https://github.com/VainF/Torch-Pruning/assets/18592211/4e5f98e9-86b6-46bd-9e9f-3275c5ccc2f4)

## Table of Contents
- [Installation](#installation)
@@ -424,7 +424,7 @@ Please refer to [benchmarks](benchmarks) for more details.
> **DeepCache: Accelerating Diffusion Models for Free** [[Project]](https://github.com/horseee/DeepCache) [[Arxiv]](https://arxiv.org/abs/2312.00858)
> *Xinyin Ma, Gongfan Fang, and Xinchao Wang*
> Preprint 2023
> CVPR 2024
> **0.1% Data Makes Segment Anything Slim** [[Project]](https://github.com/czg1225/SlimSAM) [[Arxiv]](https://arxiv.org/abs/2312.05284)
> *Zigeng Chen, Gongfan Fang, Xinyin Ma, Xinchao Wang*
4 changes: 2 additions & 2 deletions examples/sparse_training/benchmark_importance_criteria.py
@@ -50,8 +50,8 @@ def validate_model(model, val_loader):

# Importance criteria
imp_dict = {
'Group Hessian': tp.importance.OBDImportance(group_reduction='mean'),
'Single-layer Hessian': tp.importance.OBDImportance(group_reduction='first'),
'Group OBD': tp.importance.OBDImportance(group_reduction='mean'),
'Single-layer OBD': tp.importance.OBDImportance(group_reduction='first'),

'Group Taylor': tp.importance.TaylorImportance(group_reduction='mean'),
'Single-layer Taylor': tp.importance.TaylorImportance(group_reduction='first'),
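For orientation, a minimal sketch of how one entry of `imp_dict` is typically consumed by a pruner. It is not part of this commit: the model, the random labels, and constructor arguments such as `pruning_ratio`/`ignored_layers` are assumptions based on recent Torch-Pruning releases. Gradient-based criteria such as the Taylor and OBD entries need a backward pass so that `.grad` is populated before `pruner.step()`.

```python
# Sketch only (assumed usage): plugging a gradient-based criterion into a pruner.
import torch
import torch.nn.functional as F
import torch_pruning as tp
from torchvision.models import resnet18

model = resnet18(pretrained=True)
example_inputs = torch.randn(8, 3, 224, 224)
labels = torch.randint(0, 1000, (8,))            # hypothetical labels for the sketch

imp = tp.importance.TaylorImportance(group_reduction='mean')   # any entry from imp_dict
ignored_layers = [m for m in model.modules()
                  if isinstance(m, torch.nn.Linear) and m.out_features == 1000]

pruner = tp.pruner.MagnitudePruner(
    model, example_inputs,
    importance=imp,
    pruning_ratio=0.5,               # argument names may vary across versions
    ignored_layers=ignored_layers,   # never prune the classifier head
)

F.cross_entropy(model(example_inputs), labels).backward()   # populate .grad first
pruner.step()                                               # rank groups and prune
print(tp.utils.count_ops_and_params(model, example_inputs=example_inputs))
```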
63 changes: 46 additions & 17 deletions examples/sparse_training/main.py
@@ -18,13 +18,15 @@
parser.add_argument("--model", type=str, required=True)
parser.add_argument("--verbose", action="store_true", default=False)
parser.add_argument("--dataset", type=str, default="cifar100", choices=['cifar10', 'cifar100', 'modelnet40'])
parser.add_argument('--dataroot', default='data', help='path to your datasets')
parser.add_argument("--batch-size", type=int, default=128)
parser.add_argument("--total-epochs", type=int, default=100)
parser.add_argument("--lr-decay-milestones", default="60,80", type=str, help="milestones for learning rate decay")
parser.add_argument("--lr-decay-gamma", default=0.1, type=float)
parser.add_argument("--lr", default=0.01, type=float, help="learning rate")
parser.add_argument("--restore", type=str, default=None)
parser.add_argument('--output-dir', default='run', help='path where to save')
parser.add_argument("--finetune", action="store_true", default=False, help='whether finetune or not')

# For pruning
parser.add_argument("--method", type=str, default=None)
@@ -46,17 +48,34 @@

args = parser.parse_args()

def progressive_pruning(pruner, model, speed_up, example_inputs):
def progressive_pruning(pruner, model, speed_up, example_inputs, train_loader=None):
model.eval()
base_ops, _ = tp.utils.count_ops_and_params(model, example_inputs=example_inputs)
current_speed_up = 1
while current_speed_up < speed_up:
pruner.step(interactive=False)
if args.method == "obdc":
model.zero_grad()
imp=pruner.importance
imp._prepare_model(model, pruner)
for k, (imgs, lbls) in enumerate(train_loader):
if k>=10: break
imgs = imgs.to(args.device)
lbls = lbls.to(args.device)
output = model(imgs)
sampled_y = torch.multinomial(torch.nn.functional.softmax(output.cpu().data, dim=1),
1).squeeze().to(args.device)
loss_sample = F.cross_entropy(output, sampled_y)
loss_sample.backward()
imp.step()
pruner.step()
imp._rm_hooks(model)
imp._clear_buffer()
else:
pruner.step()
pruned_ops, _ = tp.utils.count_ops_and_params(model, example_inputs=example_inputs)
current_speed_up = float(base_ops) / pruned_ops
if pruner.current_step == pruner.iterative_steps:
break
#print(current_speed_up)
return current_speed_up

def eval(model, test_loader, device=None):
@@ -169,9 +188,18 @@ def get_pruner(model, example_inputs):
elif args.method == "l1":
imp = tp.importance.MagnitudeImportance(p=1)
pruner_entry = partial(tp.pruner.MagnitudePruner, global_pruning=args.global_pruning)
elif args.method == "l2":
imp = tp.importance.MagnitudeImportance(p=2)
pruner_entry = partial(tp.pruner.MagnitudePruner, global_pruning=args.global_pruning)
elif args.method == "fpgm":
imp = tp.importance.FPGMImportance(p=2)
pruner_entry = partial(tp.pruner.MagnitudePruner, global_pruning=args.global_pruning)
elif args.method == "obdc":
imp = tp.importance.OBDCImportance(group_reduction='mean', num_classes=args.num_classes)
pruner_entry = partial(tp.pruner.MagnitudePruner, global_pruning=args.global_pruning)
elif args.method == "lamp":
imp = tp.importance.LAMPImportance(p=2)
pruner_entry = partial(tp.pruner.BNScalePruner, global_pruning=args.global_pruning)
pruner_entry = partial(tp.pruner.MagnitudePruner, global_pruning=args.global_pruning)
elif args.method == "slim":
args.sparsity_learning = True
imp = tp.importance.BNScaleImportance()
@@ -242,7 +270,7 @@ def main():
# Model & Dataset
args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
num_classes, train_dst, val_dst, input_size = registry.get_dataset(
args.dataset, data_root="data"
args.dataset, data_root=args.dataroot
)
args.num_classes = num_classes
model = registry.get_model(args.model, num_classes=num_classes, pretrained=True, target_dataset=args.dataset)
@@ -315,7 +343,7 @@ def main():
ori_ops, ori_size = tp.utils.count_ops_and_params(model, example_inputs=example_inputs)
ori_acc, ori_val_loss = eval(model, test_loader, device=args.device)
args.logger.info("Pruning...")
progressive_pruning(pruner, model, speed_up=args.speed_up, example_inputs=example_inputs)
progressive_pruning(pruner, model, speed_up=args.speed_up, example_inputs=example_inputs, train_loader=train_loader)
del pruner # remove reference
args.logger.info(model)
pruned_ops, pruned_size = tp.utils.count_ops_and_params(model, example_inputs=example_inputs)
@@ -340,17 +368,18 @@
)

# 2. Finetuning
args.logger.info("Finetuning...")
train_model(
model,
epochs=args.total_epochs,
lr=args.lr,
lr_decay_milestones=args.lr_decay_milestones,
train_loader=train_loader,
test_loader=test_loader,
device=args.device,
save_state_dict_only=False,
)
if args.finetune:
args.logger.info("Finetuning...")
train_model(
model,
epochs=args.total_epochs,
lr=args.lr,
lr_decay_milestones=args.lr_decay_milestones,
train_loader=train_loader,
test_loader=test_loader,
device=args.device,
save_state_dict_only=False,
)
elif args.mode == "test":
model.eval()
ops, params = tp.utils.count_ops_and_params(
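The OBDC branch of `progressive_pruning()` is the core of this change. Below is a condensed sketch of that flow; it assumes `model`, `example_inputs`, `train_loader`, and `device` already exist, and the `pruning_ratio` argument reflects recent Torch-Pruning releases rather than anything added in this commit.

```python
# Condensed sketch of the OBDC flow introduced in progressive_pruning() above.
# model, example_inputs, train_loader and device are assumed to exist already.
import torch
import torch.nn.functional as F
import torch_pruning as tp

imp = tp.importance.OBDCImportance(group_reduction='mean', num_classes=100)  # dataset-dependent
pruner = tp.pruner.MagnitudePruner(model, example_inputs,
                                   importance=imp, pruning_ratio=0.5)

model.zero_grad()
imp._prepare_model(model, pruner)        # install hooks that gather Fisher statistics
for k, (imgs, lbls) in enumerate(train_loader):
    if k >= 10:                          # a few calibration batches, as in main.py
        break
    output = model(imgs.to(device))
    # labels are sampled from the model's own predictions (true rather than empirical Fisher)
    sampled_y = torch.multinomial(
        F.softmax(output.detach().cpu(), dim=1), 1).squeeze().to(device)
    F.cross_entropy(output, sampled_y).backward()
    imp.step()                           # accumulate this batch's Fisher estimate
pruner.step()                            # prune with the accumulated importance
imp._rm_hooks(model)                     # remove hooks and free buffers
imp._clear_buffer()
```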
4 changes: 2 additions & 2 deletions examples/sparse_training/tools/draw.py
@@ -7,8 +7,8 @@
plt.style.use('bmh')

color_dict = {
'Group Hessian': "C0",
'Single-layer Hessian': "C0",
'Group OBD': "C0",
'Single-layer OBD': "C0",

'Random': "C1",

2 changes: 1 addition & 1 deletion examples/transformers/draw_acc_curve.py
@@ -13,7 +13,7 @@ def parse_acc_from_file(file_path):
return acc

log_dict = {
'Hessian-uniform': 'output/vit_b_16_pruning_hessian_uniform/train.log',
'OBD-uniform': 'output/vit_b_16_pruning_OBD_uniform/train.log',
'Taylor-uniform': 'output/vit_b_16_pruning_taylor_uniform/train.log',
'Taylor-bottleneck': 'output/vit_b_16_pruning_taylor_bottleneck/train.log',
'L1-uniform': 'output/vit_b_16_pruning_l1_uniform/train.log',
6 changes: 3 additions & 3 deletions examples/transformers/prune_timm_vit.py
@@ -20,7 +20,7 @@ def parse_args():
parser.add_argument('--taylor_batchs', default=10, type=int, help='number of batchs for taylor criterion')
parser.add_argument('--pruning_ratio', default=0.5, type=float, help='prune ratio')
parser.add_argument('--bottleneck', default=False, action='store_true', help='bottleneck or uniform')
parser.add_argument('--pruning_type', default='l1', type=str, help='pruning type', choices=['random', 'taylor', 'l2', 'l1', 'hessian'])
parser.add_argument('--pruning_type', default='l1', type=str, help='pruning type', choices=['random', 'taylor', 'l2', 'l1', 'OBD'])
parser.add_argument('--test_accuracy', default=False, action='store_true', help='test accuracy')
parser.add_argument('--global_pruning', default=False, action='store_true', help='global pruning')
parser.add_argument('--prune_num_heads', default=False, action='store_true', help='global pruning')
@@ -111,11 +111,11 @@ def main():
imp = tp.importance.GroupNormImportance(p=2)
elif args.pruning_type == 'l1':
imp = tp.importance.GroupNormImportance(p=1)
elif args.pruning_type == 'hessian':
elif args.pruning_type == 'OBD':
imp = tp.importance.GroupOBDImportance()
else: raise NotImplementedError

if args.pruning_type in ['taylor', 'hessian'] or args.test_accuracy:
if args.pruning_type in ['taylor', 'OBD'] or args.test_accuracy:
train_loader, val_loader = prepare_imagenet(args.data_path, train_batch_size=args.train_batch_size, val_batch_size=args.val_batch_size, use_imagenet_mean_std=args.use_imagenet_mean_std)

# Load the model
4 changes: 2 additions & 2 deletions examples/transformers/readme.md
@@ -43,13 +43,13 @@ bash scripts/finetune_timm_vit_b_16_taylor_uniform.sh
```
Pruning results for ImageNet-21K-ft-1K (Timm):

| | ViT-B/16 (Timm) | ViT_B/32 (Timm) | Group L2 (Uniform) | Group Taylor (Uniform) | Group Taylor (Bottleneck) | Group Hessian (Uniform) |
| | ViT-B/16 (Timm) | ViT_B/32 (Timm) | Group L2 (Uniform) | Group Taylor (Uniform) | Group Taylor (Bottleneck) | Group OBD (Uniform) |
| :-- | :--: | :--: | :--: | :--: | :--: | :--: |
| **#Params** | 86.57 M | 88.22 M | 22.05 M | 22.05 M | 24.83 M | 22.05 M |
| **MACs** | 17.59 G | 4.41 G | 4.61 G | 4.61 G | 4.62 G | 4.61 G |
| **Acc @ Epoch 300** | 85.21 | 80.68 | 78.11 | 80.19 | 80.06 | 80.15 |
| **Latency (Bs=1, A5000)** | 5.21 ms <br> +- 0.05 ms | 3.87 ms <br> +- 0.05 ms | 3.99 ms <br> +- 0.10 ms | 3.99 ms <br> +- 0.10 ms | 3.87 ms <br> +- 0.14 ms | 3.99 ms <br> +- 0.10 ms |
| **Checkpoints** | - | - | [ckpt](https://github.com/VainF/Torch-Pruning/releases/download/v1.2.5/vit_b_16_pruning_l2_uniform.pth) | [ckpt](https://github.com/VainF/Torch-Pruning/releases/download/v1.2.5/vit_b_16_pruning_taylor_uniform.pth) | [ckpt](https://github.com/VainF/Torch-Pruning/releases/download/v1.2.5/vit_b_16_pruning_taylor_bottleneck.pth) | [ckpt](https://github.com/VainF/Torch-Pruning/releases/download/v1.2.5/vit_b_16_pruning_hessian_uniform.pth) |
| **Checkpoints** | - | - | [ckpt](https://github.com/VainF/Torch-Pruning/releases/download/v1.2.5/vit_b_16_pruning_l2_uniform.pth) | [ckpt](https://github.com/VainF/Torch-Pruning/releases/download/v1.2.5/vit_b_16_pruning_taylor_uniform.pth) | [ckpt](https://github.com/VainF/Torch-Pruning/releases/download/v1.2.5/vit_b_16_pruning_taylor_bottleneck.pth) | [ckpt](https://github.com/VainF/Torch-Pruning/releases/download/v1.2.5/vit_b_16_pruning_OBD_uniform.pth) |

*Notes:*
* Uniform - We apply the same pruning ratio to all layers.
@@ -1,5 +1,5 @@
torchrun --nproc_per_node=8 finetune.py \
--model "output/pruned/vit_base_patch16_224_pruned_hessian_uniform.pth" \
--model "output/pruned/vit_base_patch16_224_pruned_OBD_uniform.pth" \
--epochs 300 \
--batch-size 256 \
--opt adamw \
@@ -17,4 +17,4 @@ torchrun --nproc_per_node=8 finetune.py \
--ra-sampler \
--cutmix-alpha 1.0 \
--data-path "data/imagenet" \
--output-dir output/vit_b_16_pruning_hessian_uniform
--output-dir output/vit_b_16_pruning_OBD_uniform
@@ -1,10 +1,10 @@
python prune_timm_vit.py \
--model_name vit_base_patch16_224 \
--pruning_type hessian \
--pruning_type OBD \
--pruning_ratio 0.5 \
--taylor_batchs 10 \
--test_accuracy \
--data_path data/imagenet \
--train_batch_size 64 \
--val_batch_size 64 \
--save_as output/pruned/vit_base_patch16_224_pruned_hessian_uniform.pth \
--save_as output/pruned/vit_base_patch16_224_pruned_OBD_uniform.pth \
4 changes: 2 additions & 2 deletions tests/test_hessian_importance.py
@@ -2,7 +2,7 @@
from torchvision.models import resnet18
import torch_pruning as tp

def test_hessian():
def test_OBD():
model = resnet18(pretrained=True)

# Importance criteria
@@ -52,4 +52,4 @@ def test_hessian():
# ...

if __name__=="__main__":
test_hessian()
test_OBD()
19 changes: 17 additions & 2 deletions tests/test_regularization.py
@@ -18,8 +18,15 @@ def test_pruner():
[tp.importance.GroupNormImportance, tp.pruner.GroupNormPruner],
[tp.importance.BNScaleImportance, tp.pruner.BNScalePruner],
[tp.importance.GroupNormImportance, tp.pruner.GrowingRegPruner],
[tp.importance.MagnitudeImportance, tp.pruner.GroupNormPruner],
[tp.importance.LAMPImportance, tp.pruner.GroupNormPruner],
[tp.importance.OBDCImportance, tp.pruner.GroupNormPruner],
[tp.importance.FPGMImportance, tp.pruner.GroupNormPruner],
]:
imp = imp_cls()
if imp_cls == tp.importance.OBDCImportance:
imp = imp_cls(num_classes=1000)
else:
imp = imp_cls()
ignored_layer_outputs = []
# DO NOT prune the final classifier!
for m in model.modules():
@@ -37,7 +44,12 @@ def test_pruner():
)

for i in range(iterative_steps):
model(example_inputs).sum().backward()
if isinstance(imp, tp.importance.OBDCImportance):
imp._prepare_model(model, pruner)
model(example_inputs).sum().backward()
imp.step()
else:
model(example_inputs).sum().backward()
grad_dict = {}
for p in model.parameters():
if p.grad is not None:
@@ -53,6 +65,9 @@
print(name, "has no grad")
for g in pruner.step(interactive=True):
g.prune()
if isinstance(imp, tp.importance.OBDCImportance):
imp._rm_hooks(model)
imp._clear_buffer()


if __name__ == "__main__":
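`FPGMImportance` is also new to this test matrix. Unlike OBDC it needs no hooks or gradients, so a stand-alone sketch is short; constructor arguments (`pruning_ratio`, `ignored_layers`) are again assumed from recent releases, not from this commit.

```python
# Minimal sketch: FPGMImportance ranks filters by their distance to the layer's
# geometric median, so no calibration data or backward pass is required.
import torch
import torch_pruning as tp
from torchvision.models import resnet18

model = resnet18(pretrained=True)
example_inputs = torch.randn(1, 3, 224, 224)

imp = tp.importance.FPGMImportance(p=2)
ignored_layers = [m for m in model.modules()
                  if isinstance(m, torch.nn.Linear) and m.out_features == 1000]

pruner = tp.pruner.GroupNormPruner(       # pairing used in test_pruner() above
    model, example_inputs,
    importance=imp,
    pruning_ratio=0.5,                    # assumed argument name; may vary by version
    ignored_layers=ignored_layers,
)
pruner.step()
```

Because FPGM treats filters near the geometric median of their layer as redundant, the criterion depends only on the weights themselves, which is why no data pass appears in the sketch.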
2 changes: 1 addition & 1 deletion torch_pruning/dependency.py
@@ -176,7 +176,7 @@ def prune(self, idxs=None, record_history=True):
self._DG._param_to_name[pruned_parameter] = name
self._DG.module2node[pruned_parameter] = self._DG.module2node.pop(old_parameter)
self._DG.module2node[pruned_parameter].module = pruned_parameter
else: # prune nn.Module
else:
dep(idxs)

if record_history:
9 changes: 9 additions & 0 deletions torch_pruning/pruner/algorithms/metapruner.py
@@ -2,6 +2,8 @@
import torch.nn as nn
import typing, warnings

from torch_pruning.pruner.importance import OBDCImportance

from .scheduler import linear_scheduler
from ..import function
from ... import ops, dependency
@@ -239,6 +241,8 @@ def step(self, interactive=False)-> typing.Union[typing.Generator, None]:
else:
for group in pruning_method():
group.prune()
# print("gg")
# exit(0)

def estimate_importance(self, group) -> torch.Tensor:
return self.importance(group)
@@ -413,6 +417,9 @@ def prune_local(self) -> typing.Generator:

if len(pruning_idxs)==0: continue
pruning_idxs = torch.unique( torch.cat(pruning_idxs, 0) ).tolist()
if isinstance(self.importance, OBDCImportance):
self.importance.adjust_fisher(group, pruning_idxs)

group = self.DG.get_pruning_group(
module, pruning_fn, pruning_idxs)

@@ -542,6 +549,8 @@ def prune_global(self) -> typing.Generator:

if len(pruning_indices)==0: continue
pruning_indices = torch.unique(torch.cat(pruning_indices, 0)).tolist()
if isinstance(self.importance, OBDCImportance):
self.importance.adjust_fisher(group, pruning_indices)
# create pruning group
group = self.DG.get_pruning_group(
module, pruning_fn, pruning_indices)