Because the model's `forward()` takes an extra `hidden` argument, `tp.pruner.MetaPruner()` crashes. Any ideas, please?
MCVE:
```python
import torch
import torch.nn as nn
import torch_pruning as tp


class lstm_for_demonstration(nn.Module):
    """Elementary Long Short Term Memory style model which simply wraps ``nn.LSTM``.
    Not to be used for anything other than demonstration.
    """

    def __init__(self, in_dim, out_dim, depth):
        super().__init__()
        self.lstm = nn.LSTM(in_dim, out_dim, depth)

    def forward(self, inputs, hidden):
        out, hidden = self.lstm(inputs, hidden)
        return out, hidden


# shape parameters
model_dimension = 8
sequence_length = 20
batch_size = 1
lstm_depth = 1

# random data for input
inputs = torch.randn(sequence_length, batch_size, model_dimension)
# hidden is a tuple of the initial hidden state and the initial cell state
hidden = (torch.randn(lstm_depth, batch_size, model_dimension),
          torch.randn(lstm_depth, batch_size, model_dimension))

float_lstm = lstm_for_demonstration(model_dimension, model_dimension, lstm_depth)

model = float_lstm.eval()
example_inputs = inputs  # only the sequence tensor; `hidden` is not passed

# 1. Importance criterion
imp = tp.importance.GroupNormImportance(p=2)  # or GroupTaylorImportance(), GroupHessianImportance(), etc.

# 2. Initialize a pruner with the model and the importance criterion
ignored_layers = []
for m in model.modules():
    if isinstance(m, torch.nn.Linear) and m.out_features == 1000:
        ignored_layers.append(m)  # DO NOT prune the final classifier!

pruner = tp.pruner.MetaPruner(  # We can always choose MetaPruner if sparse training is not required.
    model,
    example_inputs,
    importance=imp,
    pruning_ratio=0.5,  # remove 50%
    ignored_layers=ignored_layers,
)

# 3. Prune & finetune the model
base_macs, base_nparams = tp.utils.count_ops_and_params(model, example_inputs)
pruner.step()
macs, nparams = tp.utils.count_ops_and_params(model, example_inputs)
print(f"MACs: {base_macs/1e9} G -> {macs/1e9} G, #Params: {base_nparams/1e6} M -> {nparams/1e6} M")
# finetune the pruned model here
# finetune(model)
# ...
```
Error:
```
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
/usr/local/lib/python3.10/dist-packages/torch_pruning/dependency.py in _trace(self, model, example_inputs, forward_fn, output_transform)
    780     try:
--> 781         out = model(*example_inputs)
    782     except:

7 frames

TypeError: lstm_for_demonstration.forward() takes 3 positional arguments but 21 were given

During handling of the above exception, another exception occurred:

TypeError                                 Traceback (most recent call last)
/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1539         or _global_backward_pre_hooks or _global_backward_hooks
   1540         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541     return forward_call(*args, **kwargs)
   1542 
   1543     try:

TypeError: lstm_for_demonstration.forward() missing 1 required positional argument: 'hidden'
```
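The two messages fit a tracer that first calls `model(*example_inputs)` and retries with `model(example_inputs)` if that raises. The first call is visible at line 781 of the traceback. The fallback is only inferred from the second error.

Unpacking a single tensor of shape `(20, 1, 8)` with `*` iterates over its first dimension. That passes 20 separate tensors, so `forward()` receives 21 positional arguments once `self` is counted. The retry then passes only `inputs`, and `hidden` is missing.

The plain-Python sketch below reproduces both errors. `Demo` and `call_like_tracer` are hypothetical stand-ins, not torch_pruning code. A list of 20 items plays the role of the tensor.

```python
class Demo:
    """Hypothetical stand-in with the same forward() signature as the MCVE model."""

    def forward(self, inputs, hidden):
        return inputs, hidden


def call_like_tracer(forward, example_inputs):
    """Hypothetical mimic of the call pattern suggested by the traceback:
    try forward(*example_inputs) first, then fall back to forward(example_inputs)."""
    try:
        return forward(*example_inputs)
    except TypeError:
        return forward(example_inputs)


# 20 "time steps": under `*` this unpacks into 20 positional arguments,
# just as a tensor of shape (20, 1, 8) unpacks along its first dimension.
sequence = [[[0.0] * 8] for _ in range(20)]
model = Demo()

# Passing only the sequence reproduces both errors, chained as in the traceback.
try:
    call_like_tracer(model.forward, sequence)
except TypeError as err:
    first, second = str(err.__context__), str(err)
print(first)   # ... takes 3 positional arguments but 21 were given
print(second)  # ... missing 1 required positional argument: 'hidden'

# Packing both arguments into one tuple satisfies the first call.
hidden = ("h0", "c0")  # placeholders for the (h_0, c_0) tensors
out = call_like_tracer(model.forward, (sequence, hidden))
print(out[1])  # → ('h0', 'c0')
```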
I also tried the `forward_fn` parameter of `tp.pruner.MetaPruner`, but got the same problem.
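Below is a minimal sketch of two ways to hand the tracer both arguments. It re-creates the MCVE model so it runs on its own with only `torch` installed, and it does not import torch_pruning.

- **Tuple inputs.** Packing `example_inputs` as the tuple `(inputs, hidden)` makes a `model(*example_inputs)` call equivalent to `model(inputs, hidden)`.
- **`forward_fn`.** Alternatively, a `forward_fn` can unpack the tuple itself. The `forward_fn(model, example_inputs)` calling convention is an assumption here; check it against the installed torch_pruning version.

`make_hidden` is a hypothetical helper. If pruning were to shrink the LSTM's hidden size, a pre-built `hidden` tuple with the old size would no longer fit. The helper therefore sizes `(h_0, c_0)` from the layer's current attributes.

```python
import torch
import torch.nn as nn


class lstm_for_demonstration(nn.Module):
    """Same toy model as in the MCVE."""

    def __init__(self, in_dim, out_dim, depth):
        super().__init__()
        self.lstm = nn.LSTM(in_dim, out_dim, depth)

    def forward(self, inputs, hidden):
        out, hidden = self.lstm(inputs, hidden)
        return out, hidden


def make_hidden(lstm: nn.LSTM, batch_size: int):
    """Hypothetical helper: build (h_0, c_0) from the LSTM's current sizes."""
    layers = (2 if lstm.bidirectional else 1) * lstm.num_layers
    # h_0 uses proj_size when projections are enabled; c_0 always uses hidden_size.
    h_size = lstm.proj_size if lstm.proj_size > 0 else lstm.hidden_size
    h0 = torch.randn(layers, batch_size, h_size)
    c0 = torch.randn(layers, batch_size, lstm.hidden_size)
    return h0, c0


def forward_fn(model, example_inputs):
    """Assumed convention: called with the model and the example_inputs object."""
    inputs, hidden = example_inputs
    return model(inputs, hidden)


model_dimension, sequence_length, batch_size, lstm_depth = 8, 20, 1, 1
model = lstm_for_demonstration(model_dimension, model_dimension, lstm_depth).eval()
inputs = torch.randn(sequence_length, batch_size, model_dimension)

# Option 1: a tuple, so model(*example_inputs) is model(inputs, hidden).
example_inputs = (inputs, make_hidden(model.lstm, batch_size))
out, (h_n, c_n) = model(*example_inputs)

# Option 2: the same tuple, unpacked explicitly by forward_fn.
out2, _ = forward_fn(model, example_inputs)
print(out.shape, h_n.shape)  # → torch.Size([20, 1, 8]) torch.Size([1, 1, 8])

# How it might be passed to the pruner (not executed here):
# pruner = tp.pruner.MetaPruner(model, example_inputs, importance=imp,
#                               pruning_ratio=0.5, ignored_layers=ignored_layers,
#                               forward_fn=forward_fn)
# After pruner.step(), rebuild the tuple in case the LSTM's sizes changed:
# example_inputs = (inputs, make_hidden(model.lstm, batch_size))
```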