Merge pull request #225 from VainF/v1.2
V1.2.1
VainF authored Jul 26, 2023
2 parents 9f7cfe0 + c1a5bb2 commit 83f88f4
Showing 17 changed files with 755 additions and 116 deletions.
4 changes: 2 additions & 2 deletions README.md
@@ -12,7 +12,7 @@
<a href="https://pytorch.org/"><img src="https://img.shields.io/badge/PyTorch-1.8 %20%7C%201.12 %20%7C%202.0-673ab7.svg" alt="Tested PyTorch Versions"></a>
<a href="https://opensource.org/licenses/MIT"><img src="https://img.shields.io/badge/License-MIT-4caf50.svg" alt="License"></a>
<a href="https://pepy.tech/project/Torch-Pruning"><img src="https://pepy.tech/badge/Torch-Pruning?color=2196f3" alt="Downloads"></a>
<a href="https://github.com/VainF/Torch-Pruning/releases/latest"><img src="https://img.shields.io/badge/Latest%20Version-1.2.0-3f51b5.svg" alt="Latest Version"></a>
<a href="https://github.com/VainF/Torch-Pruning/releases/latest"><img src="https://img.shields.io/badge/Latest%20Version-1.2.1-3f51b5.svg" alt="Latest Version"></a>
<a href="https://colab.research.google.com/drive/1TRvELQDNj9PwM-EERWbF3IQOyxZeDepp?usp=sharing">
<img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>
@@ -21,7 +21,7 @@

Torch-Pruning (TP) is a library for structural pruning with the following features:

* **General-purpose Pruning Toolkit:** TP enables structural pruning for a wide range of deep neural networks, including *[Large Language Models (LLMs)](https://github.com/horseee/LLM-Pruner), [Diffusion Models](https://github.com/VainF/Diff-Pruning), [Yolov7](examples/yolov7/), [yolov8](examples/yolov8/), [ViT](examples/torchvision_models/), FasterRCNN, SSD, ResNe(X)t, ConvNext, DenseNet, ConvNext, RegNet, DeepLab, etc*. Different from [torch.nn.utils.prune](https://pytorch.org/tutorials/intermediate/pruning_tutorial.html) that zeroizes parameters through masking, Torch-Pruning deploys a (non-deep) graph algorithm called **DepGraph** to remove parameters physically. Currently, TP is able to prune approximately **81/85=95.3%** of the models from Torchvision 0.13.1. Try this [Colab Demo](https://colab.research.google.com/drive/1TRvELQDNj9PwM-EERWbF3IQOyxZeDepp?usp=sharing) for a quick start.
* **General-purpose Pruning Toolkit:** TP enables structural pruning for a wide range of deep neural networks, including *[Large Language Models (LLMs)](https://github.com/horseee/LLM-Pruner), [Diffusion Models](https://github.com/VainF/Diff-Pruning), [Yolov7](examples/yolov7/), [yolov8](examples/yolov8/), [ViT](examples/hf_transformers/), FasterRCNN, SSD, ResNe(X)t, ConvNext, DenseNet, ConvNext, RegNet, DeepLab, etc*. Different from [torch.nn.utils.prune](https://pytorch.org/tutorials/intermediate/pruning_tutorial.html) that zeroizes parameters through masking, Torch-Pruning deploys a (non-deep) graph algorithm called **DepGraph** to remove parameters physically. Currently, TP is able to prune approximately **81/85=95.3%** of the models from Torchvision 0.13.1. Try this [Colab Demo](https://colab.research.google.com/drive/1TRvELQDNj9PwM-EERWbF3IQOyxZeDepp?usp=sharing) for a quick start.
* **[Performance Benchmark](benchmarks)**: Reproduce our results in the DepGraph paper.
* **[Tutorials and Documents](https://github.com/VainF/Torch-Pruning/wiki)** are available at the GitHub Wiki.

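The README bullet above summarizes the DepGraph-based workflow: trace the model with an example input, group coupled layers in a dependency graph, and physically remove channels rather than masking them. A minimal sketch of that usage follows; `GroupNormImportance` and `GroupNormPruner` appear in this commit, while `ch_sparsity`, `ignored_layers`, and `pruner.step()` are assumptions based on the 1.2.x examples and may differ in other releases.

```python
import torch
import torch_pruning as tp
from torchvision.models import resnet18

model = resnet18()
example_inputs = torch.randn(1, 3, 224, 224)

# Group-level L2 importance, as used by the "group_norm" method in the benchmark script.
imp = tp.importance.GroupNormImportance(p=2)

pruner = tp.pruner.GroupNormPruner(
    model,
    example_inputs,             # traced to build the dependency graph (DepGraph)
    importance=imp,
    ch_sparsity=0.5,            # assumed keyword: target channel sparsity
    ignored_layers=[model.fc],  # assumed keyword: keep the 1000-way classifier intact
)
pruner.step()                   # physically removes the selected channels/groups

print(model(example_inputs).shape)  # still (1, 1000); the internal layers are now thinner
```

After `step()` the pruned network is an ordinary `nn.Module` and can be fine-tuned or exported as usual.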
37 changes: 22 additions & 15 deletions benchmarks/main_imagenet.py
@@ -89,6 +89,7 @@ def get_args_parser(add_help=True):
parser.add_argument("--target-flops", type=float, default=2.0, help="GFLOPs of pruned model")
parser.add_argument("--soft-keeping-ratio", type=float, default=0.0)
parser.add_argument("--reg", type=float, default=1e-4)
parser.add_argument("--delta_reg", type=float, default=1e-4)
parser.add_argument("--max-ch-sparsity", default=1.0, type=float, help="maximum channel sparsity")
parser.add_argument("--sl-epochs", type=int, default=None)
parser.add_argument("--sl-resume", type=str, default=None)
@@ -131,6 +132,10 @@ def get_pruner(model, example_inputs, args):
elif args.method == "group_norm":
imp = tp.importance.GroupNormImportance(p=2)
pruner_entry = partial(tp.pruner.GroupNormPruner, global_pruning=args.global_pruning)
elif args.method == "group_greg":
sparsity_learning = True
imp = tp.importance.GroupNormImportance(p=2)
pruner_entry = partial(tp.pruner.GrowingRegPruner, reg=args.reg, delta_reg=args.delta_reg, global_pruning=args.global_pruning)
elif args.method == "group_sl":
sparsity_learning = True
imp = tp.importance.GroupNormImportance(p=2)
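The new `group_greg` branch registers a growing-regularization variant: the same group L2 importance, but a `GrowingRegPruner` whose penalty starts at `reg` and is increased by `delta_reg` as training progresses. A hedged sketch of constructing it directly is shown below; the keyword names `reg`, `delta_reg`, and `global_pruning` come from the diff, while the remaining arguments are assumptions about the 1.2.x pruner interface.

```python
import torch
import torch_pruning as tp
from torchvision.models import resnet50

model = resnet50()
example_inputs = torch.randn(1, 3, 224, 224)
imp = tp.importance.GroupNormImportance(p=2)

pruner = tp.pruner.GrowingRegPruner(
    model,
    example_inputs,
    importance=imp,
    reg=1e-4,         # base penalty strength (args.reg)
    delta_reg=1e-4,   # amount the penalty grows each time update_reg() is called (args.delta_reg)
    global_pruning=False,
    ch_sparsity=0.5,  # assumed keyword for the target channel sparsity
)
```

During sparsity learning the script then calls `pruner.regularize(model)` every step and `pruner.update_reg()` once per epoch, as the training-loop hunk further below shows.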
@@ -163,7 +168,7 @@ def get_pruner(model, example_inputs, args):



def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args, model_ema=None, scaler=None, regularizer=None, recover=None):
def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args, model_ema=None, scaler=None, pruner=None, recover=None):
model.train()
metric_logger = utils.MetricLogger(delimiter=" ")
metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value}"))
@@ -180,17 +185,17 @@ def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, arg
optimizer.zero_grad()
if scaler is not None:
scaler.scale(loss).backward()
if regularizer:
if pruner:
scaler.unscale_(optimizer)
regularizer(model)
pruner.regularize(model)
#if recover:
# recover(model.module)
scaler.step(optimizer)
scaler.update()
else:
loss.backward()
if regularizer:
regularizer(model)
if pruner is not None:
pruner.regularize(model)
if recover:
recover(model.module)
if args.clip_grad_norm is not None:
@@ -202,14 +207,16 @@ def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, arg
if epoch < args.lr_warmup_epochs:
# Reset ema buffer to keep copying weights during warmup period
model_ema.n_averaged.fill_(0)

acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
batch_size = image.shape[0]
metric_logger.update(loss=loss.item(), lr=optimizer.param_groups[0]["lr"])
metric_logger.meters["acc1"].update(acc1.item(), n=batch_size)
metric_logger.meters["acc5"].update(acc5.item(), n=batch_size)
metric_logger.meters["img/s"].update(batch_size / (time.time() - start_time))


if pruner is not None and isinstance(pruner, tp.pruner.GrowingRegPruner):
pruner.update_reg()

def evaluate(model, criterion, data_loader, device, print_freq=100, log_suffix=""):
model.eval()
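Replacing the bare `regularizer` callback with the `pruner` object lets the loop call `pruner.regularize(model)` after the backward pass and, for `GrowingRegPruner`, `pruner.update_reg()` once per epoch. A minimal sketch of the resulting pattern, with the AMP, EMA, and distributed branches stripped out:

```python
import torch.nn as nn
import torch_pruning as tp

def train_one_epoch_sketch(model, pruner, data_loader, optimizer, device):
    criterion = nn.CrossEntropyLoss()
    model.train()
    for image, target in data_loader:
        image, target = image.to(device), target.to(device)
        loss = criterion(model(image), target)
        optimizer.zero_grad()
        loss.backward()
        if pruner is not None:
            pruner.regularize(model)  # adds the sparsity penalty to the parameter gradients
        optimizer.step()
    # Growing regularization: strengthen the penalty once per epoch.
    if pruner is not None and isinstance(pruner, tp.pruner.GrowingRegPruner):
        pruner.update_reg()
```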
@@ -381,7 +388,7 @@ def collate_fn(batch):
train(model, args.sl_epochs,
lr=args.sl_lr, lr_step_size=args.sl_lr_step_size, lr_warmup_epochs=args.sl_lr_warmup_epochs,
train_sampler=train_sampler, data_loader=data_loader, data_loader_test=data_loader_test,
device=device, args=args, regularizer=pruner.regularize, state_dict_only=True)
device=device, args=args, pruner=pruner, state_dict_only=True)
#model.load_state_dict( torch.load('regularized_{:.4f}_best.pth'.format(args.reg), map_location='cpu')['model'] )
#utils.save_on_master(
# model_without_ddp.state_dict(),
@@ -403,14 +410,14 @@ def collate_fn(batch):
train(model, args.epochs,
lr=args.lr, lr_step_size=args.lr_step_size, lr_warmup_epochs=args.lr_warmup_epochs,
train_sampler=train_sampler, data_loader=data_loader, data_loader_test=data_loader_test,
device=device, args=args, regularizer=None, state_dict_only=(not args.prune))
device=device, args=args, pruner=None, state_dict_only=(not args.prune))

def train(
model,
epochs,
lr, lr_step_size, lr_warmup_epochs,
train_sampler, data_loader, data_loader_test,
device, args, regularizer=None, state_dict_only=True, recover=None):
device, args, pruner=None, state_dict_only=True, recover=None):

model.to(device)
if args.distributed and args.sync_bn:
@@ -421,9 +428,9 @@ def train(
else:
criterion = nn.CrossEntropyLoss()

weight_decay = args.weight_decay if regularizer is None else 0
bias_weight_decay = args.bias_weight_decay if regularizer is None else 0
norm_weight_decay = args.norm_weight_decay if regularizer is None else 0
weight_decay = args.weight_decay if pruner is None else 0
bias_weight_decay = args.bias_weight_decay if pruner is None else 0
norm_weight_decay = args.norm_weight_decay if pruner is None else 0

custom_keys_weight_decay = []
if bias_weight_decay is not None:
@@ -534,11 +541,11 @@ def train(

start_time = time.time()
best_acc = 0
prefix = '' if regularizer is None else 'regularized_{:e}_'.format(args.reg)
prefix = '' if pruner is None else 'regularized_{:e}_'.format(args.reg)
for epoch in range(args.start_epoch, epochs):
if args.distributed:
train_sampler.set_epoch(epoch)
train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args, model_ema, scaler, regularizer, recover=recover)
train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args, model_ema, scaler, pruner, recover=recover)
lr_scheduler.step()
acc = evaluate(model, criterion, data_loader_test, device=device)
if model_ema: