Skip to content

Commit

Permalink
More typing
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 committed Dec 21, 2023
1 parent edba342 commit d5f6a3c
Showing 1 changed file with 6 additions and 5 deletions.
11 changes: 6 additions & 5 deletions src/brevitas/utils/torch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,12 @@ def torch_partial_deepcopy(model):


def kthvalue(
x: torch.Tensor,
k: int,
dim: Optional[int] = None,
keepdim: bool = False,
out: Optional[Tuple[torch.Tensor, torch.LongTensor]] = None) -> torch.Tensor:
x: torch.Tensor,
k: int,
dim: Optional[int] = None,
keepdim: bool = False,
out: Optional[Tuple[torch.Tensor, torch.LongTensor]] = None
) -> Tuple[torch.Tensor, torch.LongTensor]:
# As of torch 2.1, there is no kthvalue implementation:
# - In CPU for float16
# - In GPU for bfloat16
Expand Down

0 comments on commit d5f6a3c

Please sign in to comment.