diff --git a/src/brevitas/quant_tensor/torch_handler.py b/src/brevitas/quant_tensor/torch_handler.py index 5708fd35f..6b530a320 100644 --- a/src/brevitas/quant_tensor/torch_handler.py +++ b/src/brevitas/quant_tensor/torch_handler.py @@ -153,4 +153,9 @@ def interpolate_handler( @implements(F.pixel_shuffle) def pixel_shuffle_handler(*args, **kwargs): - return quant_invariant_handler(F.pixel_shuffle_handler, *args, **kwargs) + return quant_invariant_handler(F.pixel_shuffle, *args, **kwargs) + + +@implements(F.pixel_unshuffle) +def pixel_unshuffle_handler(*args, **kwargs): + return quant_invariant_handler(F.pixel_unshuffle, *args, **kwargs) \ No newline at end of file