diff --git a/src/brevitas/utils/torch_utils.py b/src/brevitas/utils/torch_utils.py index 43540cc72..ec7d6fac4 100644 --- a/src/brevitas/utils/torch_utils.py +++ b/src/brevitas/utils/torch_utils.py @@ -50,11 +50,12 @@ def torch_partial_deepcopy(model): def kthvalue( - x: torch.Tensor, - k: int, - dim: Optional[int] = None, - keepdim: bool = False, - out: Optional[Tuple[torch.Tensor, torch.LongTensor]] = None) -> torch.Tensor: + x: torch.Tensor, + k: int, + dim: Optional[int] = None, + keepdim: bool = False, + out: Optional[Tuple[torch.Tensor, torch.LongTensor]] = None +) -> Tuple[torch.Tensor, torch.LongTensor]: # As of torch 2.1, there is no kthvalue implementation: # - In CPU for float16 # - In GPU for bfloat16