KJT methods test coverage with pt2 checks refactoring (pytorch#1988)
Summary:
Pull Request resolved: pytorch#1988

Adding dynamo coverage for KJT methods:
- permute
- split
- regroup_as_dict
- getitem
- todict

The split and getitem tests need additional checks (similar to the existing pre-slice checks).

Extracted those checks into pt2 utils as pt2_checks_tensor_slice.
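
For context, the KJT methods gaining coverage behave as follows in eager mode (an illustrative sketch only; the keys, values, and lengths below are made up and are not taken from this diff):

    import torch
    from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

    kjt = KeyedJaggedTensor.from_lengths_sync(
        keys=["f1", "f2"],
        values=torch.tensor([1, 2, 3, 4, 5]),
        lengths=torch.tensor([2, 1, 1, 1]),
    )

    splits = kjt.split([1, 1])      # list of sub-KJTs with one key each
    permuted = kjt.permute([1, 0])  # keys reordered to ["f2", "f1"]
    jt = kjt["f1"]                  # JaggedTensor for a single key (getitem)
    jt_dict = kjt.to_dict()         # Dict[str, JaggedTensor] (todict)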

Reviewed By: PaulZhang12

Differential Revision: D57220897

fbshipit-source-id: 4a6314e6ddbf7b5e5d8ad25f72aa65906cff28d7
Ivan Kobzarev authored and facebook-github-bot committed May 13, 2024
1 parent b97efd5 commit 9fd4bc3
Showing 5 changed files with 310 additions and 54 deletions.
35 changes: 35 additions & 0 deletions torchrec/distributed/test_utils/infer_utils.py
@@ -202,6 +202,41 @@ def prep_inputs(
return inputs


class KJTInputExportWrapperWithStrides(torch.nn.Module):
"""
    Version of KJTInputExportWrapper with a stride_per_key_per_rank argument for the VB (variable batch) path.
"""

def __init__(
self,
module_kjt_input: torch.nn.Module,
kjt_keys: List[str],
) -> None:
super().__init__()
self._module_kjt_input = module_kjt_input
self._kjt_keys = kjt_keys

# pyre-ignore
def forward(
self,
values: torch.Tensor,
lengths: torch.Tensor,
stride_per_key_per_rank: Optional[List[List[int]]],
# pyre-ignore
*args,
# pyre-ignore
**kwargs,
):
kjt = KeyedJaggedTensor(
keys=self._kjt_keys,
values=values,
lengths=lengths,
stride_per_key_per_rank=stride_per_key_per_rank,
)
output = self._module_kjt_input(kjt, *args, **kwargs)
return [leaf for leaf in pytree.tree_leaves(output) if leaf is not None]


def prep_inputs_multiprocess(
model_info: TestModelInfo, world_size: int, batch_size: int = 1, count: int = 5
) -> List[Tuple[ModelInput, List[ModelInput]]]:
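As a usage note for KJTInputExportWrapperWithStrides above, a rough sketch of how the wrapper is fed flattened KJT fields (mirroring the kjt_module_kjt_inputs_with_strides helper in the tests below; my_kjt_module and kjt are assumed to exist and are not part of this diff):

    from torchrec.pt2.utils import kjt_for_pt2_tracing

    wrapper = KJTInputExportWrapperWithStrides(my_kjt_module, kjt.keys())
    traced_kjt = kjt_for_pt2_tracing(kjt)
    out = wrapper(
        traced_kjt._values,
        traced_kjt._lengths,
        traced_kjt._stride_per_key_per_rank,
    )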
204 changes: 178 additions & 26 deletions torchrec/distributed/tests/test_pt2.py
@@ -7,15 +7,21 @@

# pyre-ignore-all-errors

import itertools
import sys
import unittest
from typing import Any, List, Tuple
from enum import auto, Enum
from typing import Any, Dict, List, Tuple

import torch
from hypothesis import given, settings, strategies as st
from torchrec.distributed.test_utils.infer_utils import (
KJTInputExportDynamicShapeWrapper,
KJTInputExportWrapperWithStrides,
TestQuantFPEBCSharder,
)
from torchrec.pt2.utils import kjt_for_pt2_tracing
from torchrec.sparse.jagged_tensor import KeyedTensor

try:
# pyre-ignore
@@ -40,7 +46,11 @@
TestQuantEBCSharder,
)
from torchrec.distributed.types import BoundsCheckMode, ShardingEnv, ShardingType
from torchrec.sparse.jagged_tensor import ComputeKJTToJTDict, KeyedJaggedTensor
from torchrec.sparse.jagged_tensor import (
ComputeKJTToJTDict,
JaggedTensor,
KeyedJaggedTensor,
)


def make_kjt(values: List[int], lengths: List[int]) -> KeyedJaggedTensor:
@@ -57,6 +67,14 @@ def make_kjt(values: List[int], lengths: List[int]) -> KeyedJaggedTensor:
return kjt


def kjt_module_kjt_inputs_with_strides(kjt: KeyedJaggedTensor) -> Tuple:
return (
kjt._values,
kjt._lengths,
kjt._stride_per_key_per_rank,
)


def _sharded_quant_ebc_model(
local_device: str = "cuda",
compute_device: str = "cuda",
@@ -126,6 +144,11 @@ def _sharded_quant_ebc_model(
return model, input_kjts


class _TestType(Enum):
EXPORT = auto()
DYNAMO_COMPILE = auto()


class TestPt2(unittest.TestCase):
def _test_kjt_input_module(
self,
@@ -176,36 +199,137 @@ def _test_kjt_input_module(
pt2_ir_output = pt2_ir.module()(*em_inputs)
assert_close(eager_output, pt2_ir_output)

def test_kjt_split(self) -> None:
    # Separate compile helper for Dynamo, as it falls back to the VB (variable batch) path.
    # Torchrec modules are lazily initialized from their first input, so eager mode must be
    # run with the same tracing inputs. Other test cases do not need to go through the VB path.
def _test_kjt_input_module_dynamo_compile(
self,
kjt_input_module: torch.nn.Module,
kjt_keys: List[str],
# pyre-ignore
inputs,
backend: str = "eager",
) -> None:
with dynamo_skipfiles_allow("torchrec"):
EM: torch.nn.Module = KJTInputExportWrapperWithStrides(
kjt_input_module, kjt_keys
)
eager_output = EM(*inputs)
torch._dynamo.config.capture_scalar_outputs = True
torch._dynamo.config.capture_dynamic_output_shape_ops = True

dynamo_eager_out = torch.compile(EM, backend=backend, fullgraph=True)(
*inputs
)
assert_close(eager_output, dynamo_eager_out)

@given(
test_type_backend=st.sampled_from(
[(_TestType.EXPORT, ""), (_TestType.DYNAMO_COMPILE, "aot_eager")]
)
)
@settings(deadline=None)
def test_kjt_split(self, test_type_backend: Tuple[_TestType, str]) -> None:
test_type, backend = test_type_backend

class M(torch.nn.Module):
def forward(self, kjt: KeyedJaggedTensor):
return kjt.split([1, 2, 1])

kjt: KeyedJaggedTensor = make_kjt([2, 3, 4, 5, 6], [1, 2, 1, 1])
if test_type == _TestType.EXPORT:
self._test_kjt_input_module(
M(),
kjt,
(),
test_aot_inductor=False,
test_dynamo=False,
test_pt2_ir_export=True,
)
elif test_type == _TestType.DYNAMO_COMPILE:
self._test_kjt_input_module_dynamo_compile(
M(),
kjt.keys(),
kjt_module_kjt_inputs_with_strides(kjt_for_pt2_tracing(kjt)),
backend=backend,
)

self._test_kjt_input_module(
M(),
kjt,
(),
test_aot_inductor=False,
test_dynamo=False,
test_pt2_ir_export=True,
@given(
test_type_backend=st.sampled_from(
[(_TestType.EXPORT, ""), (_TestType.DYNAMO_COMPILE, "aot_eager")]
)
)
@settings(deadline=None)
def test_kjt_permute(self, test_type_backend: Tuple[_TestType, str]) -> None:
test_type, backend = test_type_backend

def test_kjt_permute(self) -> None:
class M(torch.nn.Module):
def forward(self, kjt: KeyedJaggedTensor, indices: List[int]):
return kjt.permute(indices)

kjt: KeyedJaggedTensor = make_kjt([2, 3, 4, 5, 6], [1, 2, 1, 1])
indices: List[int] = [1, 0, 3, 2]
self._test_kjt_input_module(
M(),
kjt,
(indices,),
test_aot_inductor=False,
test_pt2_ir_export=True,
)

if test_type == _TestType.EXPORT:
self._test_kjt_input_module(
M(),
kjt,
(indices,),
test_aot_inductor=False,
test_pt2_ir_export=True,
)
elif test_type == _TestType.DYNAMO_COMPILE:

def inputs_fn(kjt):
return *kjt_module_kjt_inputs_with_strides(kjt), indices

self._test_kjt_input_module_dynamo_compile(
M(),
kjt.keys(),
inputs_fn(kjt_for_pt2_tracing(kjt)),
backend=backend,
)

@unittest.skipIf(
torch.cuda.device_count() <= 1,
"Not enough GPUs, this test requires at least two GPUs",
)
def test_kt_regroup_as_dict(
self,
) -> None:

class M(torch.nn.Module):
def forward(self, inputs: List[KeyedTensor]) -> Dict[str, torch.Tensor]:
groups = [["dense_0", "sparse_1", "dense_2"], ["dense_1", "sparse_0"]]
keys = ["group_0", "group_1"]
return KeyedTensor.regroup_as_dict(inputs, groups, keys)

m = M()

key_dim = 1
tensor_list_1 = [torch.randn(2, 3) for i in range(3)]
keys_1 = ["dense_0", "dense_1", "dense_2"]
kt_1 = KeyedTensor.from_tensor_list(keys_1, tensor_list_1, key_dim)
tensor_list_2 = [torch.randn(2, 3) for i in range(2)]
keys_2 = ["sparse_0", "sparse_1"]
kt_2 = KeyedTensor.from_tensor_list(keys_2, tensor_list_2, key_dim)
inputs = [kt_1, kt_2]

for t in itertools.chain(tensor_list_1, tensor_list_2):
torch._dynamo.decorators.mark_dynamic(t, 0)
torch._dynamo.decorators.mark_dynamic(t, 1)

eager_output = m(inputs)
with dynamo_skipfiles_allow("torchrec"):
torch_compile_backend = "eager"

torch._dynamo.config.capture_scalar_outputs = True
torch._dynamo.config.capture_dynamic_output_shape_ops = True
opt_fn = torch.compile(
m, backend=torch_compile_backend, fullgraph=True, dynamic=True
)
compile_output = opt_fn(inputs)
torch.testing.assert_close(eager_output, compile_output)

def test_kjt_length_per_key(self) -> None:
class M(torch.nn.Module):
@@ -237,8 +361,15 @@ def forward(self, kjt: KeyedJaggedTensor):
test_pt2_ir_export=True,
)

# pyre-ignore
def test_kjt__getitem__(self) -> None:
@given(
test_type_backend=st.sampled_from(
[(_TestType.EXPORT, ""), (_TestType.DYNAMO_COMPILE, "aot_eager")]
)
)
@settings(deadline=None)
def test_kjt__getitem__(self, test_type_backend: Tuple[_TestType, str]) -> None:
test_type, backend = test_type_backend

class M(torch.nn.Module):
def forward(self, kjt: KeyedJaggedTensor):
out0 = kjt["key0"]
@@ -249,13 +380,34 @@ def forward(self, kjt: KeyedJaggedTensor):
# First element represents symint for values and weights shape
kjt: KeyedJaggedTensor = make_kjt([2, 3, 4, 5, 6], [1, 2, 1, 1])

self._test_kjt_input_module(
if test_type == _TestType.EXPORT:
self._test_kjt_input_module(
M(),
kjt,
(),
test_dynamo=False,
test_aot_inductor=False,
test_pt2_ir_export=True,
)
elif test_type == _TestType.DYNAMO_COMPILE:
self._test_kjt_input_module_dynamo_compile(
M(),
kjt.keys(),
kjt_module_kjt_inputs_with_strides(kjt_for_pt2_tracing(kjt)),
backend=backend,
)

def test_kjt_to_dict_with_strides_dynamo(self) -> None:
class M(torch.nn.Module):
def forward(self, kjt: KeyedJaggedTensor) -> Dict[str, JaggedTensor]:
return kjt.to_dict()

kjt: KeyedJaggedTensor = make_kjt([2, 3, 4, 5, 6], [1, 2, 1, 1])

self._test_kjt_input_module_dynamo_compile(
M(),
kjt,
(),
test_dynamo=False,
test_aot_inductor=False,
test_pt2_ir_export=True,
kjt.keys(),
kjt_module_kjt_inputs_with_strides(kjt_for_pt2_tracing(kjt)),
)

# pyre-ignores
56 changes: 56 additions & 0 deletions torchrec/pt2/checks.py
@@ -0,0 +1,56 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from typing import List

import torch


try:
if torch.jit.is_scripting():
raise Exception()

from torch.compiler import (
is_compiling as is_compiler_compiling,
is_dynamo_compiling as is_torchdynamo_compiling,
)

def is_non_strict_exporting() -> bool:
return not is_torchdynamo_compiling() and is_compiler_compiling()

except Exception:
# BC for torch versions without compiler and torch deploy path
def is_torchdynamo_compiling() -> bool: # type: ignore[misc]
return False

def is_non_strict_exporting() -> bool:
return False


def pt2_checks_tensor_slice(
tensor: torch.Tensor, start_offset: int, end_offset: int, dim: int = 0
) -> None:
if torch.jit.is_scripting() or not is_torchdynamo_compiling():
return

torch._check_is_size(start_offset)
torch._check_is_size(end_offset)
torch._check_is_size(end_offset - start_offset)
torch._check(start_offset <= tensor.size(dim))
torch._check(end_offset <= tensor.size(dim))
torch._check(end_offset >= start_offset)


def pt2_checks_all_is_size(list: List[int]) -> List[int]:
if torch.jit.is_scripting() or not is_torchdynamo_compiling():
return list

for i in list:
torch._check_is_size(i)
return list
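
As an illustration of how these checks are intended to be used (a hypothetical helper, not part of this diff; the import path follows the new torchrec/pt2/checks.py module above):

    import torch
    from torchrec.pt2.checks import pt2_checks_tensor_slice

    def slice_values(values: torch.Tensor, start: int, end: int) -> torch.Tensor:
        # Under torch.compile, assert that start/end are valid, in-bounds sizes
        # so data-dependent slicing does not force a graph break.
        pt2_checks_tensor_slice(values, start, end)
        return values[start:end]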