Skip to content

Commit

Permalink
[Performance] Make _to_consolidated compatible with compile
Browse files Browse the repository at this point in the history
ghstack-source-id: dea67cba6a5e6b33a22a225cf48bd35afb439fac
Pull Request resolved: #1041
  • Loading branch information
vmoens committed Oct 18, 2024
1 parent 75b33c4 commit 1d92ab1
Show file tree
Hide file tree
Showing 5 changed files with 425 additions and 122 deletions.
131 changes: 108 additions & 23 deletions benchmarks/common/h2d_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,26 +4,39 @@
# LICENSE file in the root directory of this source tree.

import argparse
from typing import Any

import pytest
import torch
from packaging import version

from tensordict import TensorDict
from tensordict import tensorclass, TensorDict
from tensordict.utils import logger as tensordict_logger

TORCH_VERSION = version.parse(version.parse(torch.__version__).base_version)


@pytest.fixture
def td():
return TensorDict(
{
str(i): {str(j): torch.randn(16, 16, device="cpu") for j in range(16)}
for i in range(16)
},
batch_size=[16],
device="cpu",
)
@tensorclass
class NJT:
_values: torch.Tensor
_offsets: torch.Tensor
_lengths: torch.Tensor
njt_shape: Any = None

@classmethod
def from_njt(cls, njt_tensor):
return NJT(
_values=njt_tensor._values,
_offsets=njt_tensor._offsets,
_lengths=njt_tensor._lengths,
njt_shape=njt_tensor.size(0),
)


@pytest.fixture(autouse=True, scope="function")
def empty_compiler_cache():
torch._dynamo.reset_code_caches()
yield


def _make_njt():
Expand All @@ -34,14 +47,27 @@ def _make_njt():
)


@pytest.fixture
def njt_td():
def _njt_td():
return TensorDict(
{str(i): {str(j): _make_njt() for j in range(32)} for i in range(32)},
device="cpu",
)


@pytest.fixture
def njt_td():
return _njt_td()


@pytest.fixture
def td():
njtd = _njt_td()
for k0, v0 in njtd.items():
for k1, v1 in v0.items():
njtd[k0, k1] = NJT.from_njt(v1)
return njtd


@pytest.fixture
def default_device():
if torch.cuda.is_available():
Expand All @@ -52,22 +78,81 @@ def default_device():
pytest.skip("CUDA/MPS is not available")


@pytest.mark.parametrize("consolidated", [False, True])
@pytest.mark.parametrize(
"consolidated,compile_mode,num_threads",
[
[False, False, None],
[True, False, None],
["within", False, None],
# [True, False, 4],
# [True, False, 16],
# [True, "default", None],
],
)
@pytest.mark.skipif(
TORCH_VERSION < version.parse("2.5.0"), reason="requires torch>=2.5"
)
class TestTo:
def test_to(self, benchmark, consolidated, td, default_device):
if consolidated:
td = td.consolidate()
benchmark(lambda: td.to(default_device))
def test_to(
self, benchmark, consolidated, td, default_device, compile_mode, num_threads
):
tensordict_logger.info(f"td size {td.bytes() / 1024 / 1024:.2f} Mb")
pin_mem = default_device.type == "cuda"
if consolidated is True:
td = td.consolidate(pin_memory=pin_mem, set_on_tensor=True)

if consolidated == "within":

def to(td, num_threads):
return td.consolidate(pin_memory=pin_mem, set_on_tensor=True).to(
default_device, num_threads=num_threads
)

else:

def to(td, num_threads):
return td.to(default_device, num_threads=num_threads)

if compile_mode:
to = torch.compile(to, mode=compile_mode)

for _ in range(3):
to(td, num_threads=num_threads)

benchmark(to, td, num_threads)

def test_to_njt(self, benchmark, consolidated, njt_td, default_device):
if consolidated:
njt_td = njt_td.consolidate()
benchmark(lambda: njt_td.to(default_device))
def test_to_njt(
self, benchmark, consolidated, njt_td, default_device, compile_mode, num_threads
):
tensordict_logger.info(f"njtd size {njt_td.bytes() / 1024 / 1024 :.2f} Mb")
pin_mem = default_device.type == "cuda"
if consolidated is True:
njt_td = njt_td.consolidate(pin_memory=pin_mem, set_on_tensor=True)

if consolidated == "within":

def to(td, num_threads):
return td.consolidate(pin_memory=pin_mem, set_on_tensor=True).to(
default_device, num_threads=num_threads
)

else:

def to(td, num_threads):
return td.to(default_device, num_threads=num_threads)

if compile_mode:
to = torch.compile(to, mode=compile_mode)

for _ in range(3):
to(njt_td, num_threads=num_threads)

benchmark(to, njt_td, num_threads)


if __name__ == "__main__":
args, unknown = argparse.ArgumentParser().parse_known_args()
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
pytest.main(
[__file__, "--capture", "no", "--exitfirst", "--benchmark-group-by", "func"]
+ unknown
)
6 changes: 6 additions & 0 deletions benchmarks/compile/compile_td_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,12 @@ class MyTensorClass:
f: torch.Tensor


@pytest.fixture(autouse=True, scope="function")
def empty_compiler_cache():
torch._dynamo.reset_code_caches()
yield


# Functions
def add_one(td):
return td + 1
Expand Down
Loading

0 comments on commit 1d92ab1

Please sign in to comment.