skip IntVarTensor for check_outputs and check_nan_and_inf routines #842

Closed
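The motivation, as a minimal sketch (the rationale below is inferred from the diff, not quoted from the PR discussion): an IntVarTensor holds a compile-time shape value rather than a device data buffer, so the generated NaN/Inf probes and output dumps have nothing to inspect on it. Both debug code paths in codegen.py therefore skip IntVarTensor nodes.

    from aitemplate.compiler.base import IntImm, IntVarTensor

    # An IntVarTensor wraps a dimension value (the constant 12 here, exactly
    # as constructed in the new unit test), not a tensor of floats, so a
    # NaN/Inf check or an output dump does not apply to it.
    f2 = IntVarTensor(IntImm(12))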
14 changes: 8 additions & 6 deletions python/aitemplate/backend/codegen.py
@@ -774,13 +774,15 @@ def _process_src_ops(self, node: Tensor) -> None:
         self.state_record.add(func._attrs["name"])
         self._process_dims_for_op(func)

-        if self.debug_settings.check_all_nan_and_inf or node._attrs.get(
-            "check_nan_and_inf", False
-        ):
+        if (
+            self.debug_settings.check_all_nan_and_inf
+            or node._attrs.get("check_nan_and_inf", False)
+        ) and (not isinstance(node, IntVarTensor)):
             self._append_check_nan_and_inf(node)
-        if self.debug_settings.check_all_outputs or node._attrs.get(
-            "check_outputs", False
-        ):
+        if (
+            self.debug_settings.check_all_outputs
+            or node._attrs.get("check_outputs", False)
+        ) and (not isinstance(node, IntVarTensor)):
             self._append_check_outputs(node)

     def _append_check_nan_and_inf(self, node: Tensor):
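For reference, a minimal sketch of how the two switches guarded above are driven from user code. The AITDebugSettings flags and the per-tensor _attrs keys are the ones read in the condition; the output tensor Y, the target, and the test name are hypothetical placeholders:

    from aitemplate.compiler import compile_model
    from aitemplate.utils.debug_settings import AITDebugSettings

    # Global switches: instrument every tensor codegen visits
    # (IntVarTensor nodes are now skipped by the guard above).
    debug_settings = AITDebugSettings(
        check_all_nan_and_inf=True,
        check_all_outputs=True,
    )

    # Per-tensor switches: the same IntVarTensor guard applies here too.
    # Y._attrs["check_nan_and_inf"] = True
    # Y._attrs["check_outputs"] = True

    # module = compile_model(Y, target, "./tmp", "debug_demo",
    #                        debug_settings=debug_settings)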
51 changes: 50 additions & 1 deletion tests/unittest/util/test_debug_utils.py
@@ -21,10 +21,12 @@
 import torch

 from aitemplate.compiler import compile_model, ops
-from aitemplate.compiler.base import IntImm
+from aitemplate.compiler.base import IntImm, IntVarTensor
 from aitemplate.compiler.ops.common.epilogue import FuncEnum
 from aitemplate.frontend import Tensor
 from aitemplate.testing import detect_target
+from aitemplate.testing.test_utils import get_random_torch_tensor
+from aitemplate.utils import shape_utils
 from aitemplate.utils.debug_settings import AITDebugSettings
 from aitemplate.utils.torch_utils import string_to_torch_dtype

@@ -125,6 +127,53 @@ def test_outputs_bf16(capfd):
     _test_outputs(True, True, "test_outputs_both_bfloat16", "bfloat16", capfd)


+def _test_with_int_var_tensor(test_name, dtype):
+    target = detect_target()
+    batch_size = (3, 5)
+    x1_size = (2, 3)
+    X_shape = (32, 64)
+    b_dim = shape_utils.gen_int_var_min_max(batch_size, name="input_batch")
+    x1_dim = shape_utils.gen_int_var_min_max(x1_size, name="input_size")
+    X = Tensor(
+        shape=[b_dim, x1_dim, *X_shape],
+        dtype=dtype,
+        name="input_0",
+        is_input=True,
+    )
+
+    Y1 = ops.size()(X)
+    Y2 = ops.getitem()(Y1, 0)
+    Y3 = ops.getitem()(Y1, 1)
+    Y4 = ops.getitem()(Y1, 2)
+    Y5 = ops.getitem()(Y1, 3)
+    f1 = ops.int_elementwise(FuncEnum.MUL)(Y4, Y5)
+    f2 = IntVarTensor(IntImm(12))
+
+    Y = ops.reshape()(X, [Y2 * Y3 * f1 / f2, f2])
+    Y._attrs["name"] = "output_0"
+    Y._attrs["is_output"] = True
+    debug_settings = AITDebugSettings(
+        check_all_outputs=True, check_all_nan_and_inf=True
+    )
+    module = compile_model(Y, target, "./tmp", test_name, debug_settings=debug_settings)
+
+    for b, x1 in zip(batch_size, x1_size):
+        X_shape_pt = (b, x1, *X_shape)
+        X_pt = get_random_torch_tensor(X_shape_pt, dtype=dtype)
+        Y_pt = X_pt.reshape(
+            int(X_shape_pt[0] * X_shape_pt[1] * X_shape_pt[2] * X_shape_pt[3] / 12),
+            12,
+        )
+
+        y = torch.empty_like(Y_pt)
+        module.run_with_tensors([X_pt], [y])
+        assert torch.allclose(Y_pt, y, atol=1e-2, rtol=1e-2)
+
+
+def test_int_var_tensor(capfd):
+    _test_with_int_var_tensor("test_outputs_int_var_tensor", "float16")
+
+
 def _test_special_outputs(
     check_tensor, check_all, test_name, capfd: pytest.CaptureFixture[str]
 ):
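As a sanity check on the shape arithmetic the new test exercises (a worked example, not part of the diff): the reshape target is [Y2 * Y3 * f1 / f2, f2] with f2 fixed at 12, so each dynamic profile must flatten to a whole number of rows of width 12.

    # Worked arithmetic for the two (batch, x1) profiles zipped in the test.
    for b, x1 in [(3, 2), (5, 3)]:
        rows = b * x1 * 32 * 64 // 12
        print(f"batch={b}, x1={x1} -> output shape ({rows}, 12)")
    # batch=3, x1=2 -> output shape (1024, 12)
    # batch=5, x1=3 -> output shape (2560, 12)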