forked from pytorch/torchrec
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
KJT methods test coverage with pt2 checks refactoring (pytorch#1988)
Summary: Pull Request resolved: pytorch#1988 Adding dynamo coverage for KJT methods: - permute - split - regroup_as_dict - getitem - todict split and getitem tests need additional checks (similar to pre slice check). Extracted those checks into pt2/utils, pt2_checks_tensor_slice. Reviewed By: PaulZhang12 Differential Revision: D57220897 fbshipit-source-id: 4a6314e6ddbf7b5e5d8ad25f72aa65906cff28d7
- Loading branch information
1 parent
b97efd5
commit 9fd4bc3
Showing
5 changed files
with
310 additions
and
54 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
#!/usr/bin/env python3 | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# All rights reserved. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
# pyre-strict | ||
|
||
from typing import List | ||
|
||
import torch | ||
|
||
|
||
try: | ||
if torch.jit.is_scripting(): | ||
raise Exception() | ||
|
||
from torch.compiler import ( | ||
is_compiling as is_compiler_compiling, | ||
is_dynamo_compiling as is_torchdynamo_compiling, | ||
) | ||
|
||
def is_non_strict_exporting() -> bool: | ||
return not is_torchdynamo_compiling() and is_compiler_compiling() | ||
|
||
except Exception: | ||
# BC for torch versions without compiler and torch deploy path | ||
def is_torchdynamo_compiling() -> bool: # type: ignore[misc] | ||
return False | ||
|
||
def is_non_strict_exporting() -> bool: | ||
return False | ||
|
||
|
||
def pt2_checks_tensor_slice( | ||
tensor: torch.Tensor, start_offset: int, end_offset: int, dim: int = 0 | ||
) -> None: | ||
if torch.jit.is_scripting() or not is_torchdynamo_compiling(): | ||
return | ||
|
||
torch._check_is_size(start_offset) | ||
torch._check_is_size(end_offset) | ||
torch._check_is_size(end_offset - start_offset) | ||
torch._check(start_offset <= tensor.size(dim)) | ||
torch._check(end_offset <= tensor.size(dim)) | ||
torch._check(end_offset >= start_offset) | ||
|
||
|
||
def pt2_checks_all_is_size(list: List[int]) -> List[int]: | ||
if torch.jit.is_scripting() or not is_torchdynamo_compiling(): | ||
return list | ||
|
||
for i in list: | ||
torch._check_is_size(i) | ||
return list |
Oops, something went wrong.