How to use tp.pruner.MetaPruner with models that have hidden parameter in forward method? #407

gsamaras opened this issue on Jul 25, 2024
Because the model has a hidden parameter in its forward() method, tp.pruner.MetaPruner() crashes. Any ideas, please?

MCVE:

import torch
import torch.nn as nn
import torch_pruning as tp

class lstm_for_demonstration(nn.Module):
    """Elementary Long Short-Term Memory style model which simply wraps nn.LSTM.
       Not to be used for anything other than demonstration.
    """
    def __init__(self, in_dim, out_dim, depth):
        super(lstm_for_demonstration, self).__init__()
        self.lstm = nn.LSTM(in_dim, out_dim, depth)

    def forward(self, inputs, hidden):
        out, hidden = self.lstm(inputs, hidden)
        return out, hidden

# shape parameters
model_dimension = 8
sequence_length = 20
batch_size = 1
lstm_depth = 1

# random data for input
inputs = torch.randn(sequence_length, batch_size, model_dimension)
# hidden is actually a tuple of the initial hidden state and the initial cell state
hidden = (torch.randn(lstm_depth, batch_size, model_dimension),
          torch.randn(lstm_depth, batch_size, model_dimension))

float_lstm = lstm_for_demonstration(model_dimension, model_dimension, lstm_depth)
model = float_lstm.eval()

example_inputs = inputs

# 1. Importance criterion
imp = tp.importance.GroupNormImportance(p=2)  # or GroupTaylorImportance(), GroupHessianImportance(), etc.

# 2. Initialize a pruner with the model and the importance criterion
ignored_layers = []
for m in model.modules():
    if isinstance(m, torch.nn.Linear) and m.out_features == 1000:
        ignored_layers.append(m)  # DO NOT prune the final classifier!

pruner = tp.pruner.MetaPruner(  # We can always choose MetaPruner if sparse training is not required.
    model,
    example_inputs,
    importance=imp,
    pruning_ratio=0.5,  # remove 50%
    ignored_layers=ignored_layers,
)

# 3. Prune & finetune the model
base_macs, base_nparams = tp.utils.count_ops_and_params(model, example_inputs)
pruner.step()
macs, nparams = tp.utils.count_ops_and_params(model, example_inputs)
print(f"MACs: {base_macs/1e9} G -> {macs/1e9} G, #Params: {base_nparams/1e6} M -> {nparams/1e6} M")
# finetune the pruned model here
# finetune(model)
# ...

Error:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
/usr/local/lib/python3.10/dist-packages/torch_pruning/dependency.py in _trace(self, model, example_inputs, forward_fn, output_transform)
    780             try:
--> 781                 out = model(*example_inputs)
    782             except:

7 frames
TypeError: lstm_for_demonstration.forward() takes 3 positional arguments but 21 were given

During handling of the above exception, another exception occurred:

TypeError                                 Traceback (most recent call last)
/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1539                 or _global_backward_pre_hooks or _global_backward_hooks
   1540                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541             return forward_call(*args, **kwargs)
   1542 
   1543         try:

TypeError: lstm_for_demonstration.forward() missing 1 required positional argument: 'hidden'

I also tried the forward_fn parameter of tp.pruner.MetaPruner, but hit the same problem.
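For reference, this is roughly the forward_fn variant I tried. It assumes, based on the _trace signature in the traceback above, that forward_fn is called as forward_fn(model, example_inputs); that convention is my reading of torch_pruning's dependency.py, not documented behavior:

def forward_fn(model, example_inputs):
    # unpack the (inputs, hidden) pair ourselves so forward() gets both arguments
    inputs, hidden = example_inputs
    out, _ = model(inputs, hidden)
    return out  # return a single tensor so the tracer sees one output

pruner = tp.pruner.MetaPruner(
    model,
    example_inputs=(inputs, hidden),  # keep both forward() arguments together
    importance=imp,
    pruning_ratio=0.5,
    ignored_layers=ignored_layers,
    forward_fn=forward_fn,
)

Since the traceback shows _trace doing model(*example_inputs), passing example_inputs=(inputs, hidden) alone (without forward_fn) might also line the star-unpacking up with forward()'s two arguments, though I have not verified how a nested tuple like hidden is handled there.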
