diff --git a/torch_pruning/dependency.py b/torch_pruning/dependency.py index 06f36e9..5167c3c 100644 --- a/torch_pruning/dependency.py +++ b/torch_pruning/dependency.py @@ -1193,7 +1193,10 @@ def _update_expand_index_mapping(self, node: Node): return if out_channels is not None: # =0 if there is a residual connection to model inputs break - assert hasattr(node.grad_fn, '_saved_self_sym_sizes'), "New version of PyTorch is required for expand operation." + if not hasattr(node.grad_fn, '_saved_self_sym_sizes'): + #warnings.warn("Expand operation detected but the shape information is not available") + return + if len(node.grad_fn._saved_self_sym_sizes) != 5: return