[Performance] Make _to_consolidated compatible with compile #1041

Status: Open. Wants to merge 42 commits into base: gh/vmoens/30/base. Changes shown from all commits.
194 changes: 170 additions & 24 deletions benchmarks/common/h2d_test.py
@@ -4,26 +4,40 @@
# LICENSE file in the root directory of this source tree.

import argparse
import time
from typing import Any

import pytest
import torch
from packaging import version

from tensordict import TensorDict
from tensordict import tensorclass, TensorDict
from tensordict.utils import logger as tensordict_logger

TORCH_VERSION = version.parse(version.parse(torch.__version__).base_version)


@pytest.fixture
def td():
return TensorDict(
{
str(i): {str(j): torch.randn(16, 16, device="cpu") for j in range(16)}
for i in range(16)
},
batch_size=[16],
device="cpu",
)
@tensorclass
class NJT:
_values: torch.Tensor
_offsets: torch.Tensor
_lengths: torch.Tensor
njt_shape: Any = None

@classmethod
def from_njt(cls, njt_tensor):
return cls(
_values=njt_tensor._values,
_offsets=njt_tensor._offsets,
_lengths=njt_tensor._lengths,
njt_shape=njt_tensor.size(0),
).clone()
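# Background sketch (not part of the diff; assumes torch>=2.5, names are
# illustrative): a nested tensor with layout=torch.jagged stores a flat
# _values buffer plus _offsets (and optionally _lengths) index tensors,
# which is what the NJT tensorclass above captures as plain fields so the
# jagged data can travel through TensorDict machinery.
_example_rows = [torch.randn(n) for n in (2, 5, 3)]
_example_njt = torch.nested.nested_tensor(_example_rows, layout=torch.jagged)
assert _example_njt._values.shape[0] == 10  # flat buffer: rows concatenated
assert _example_njt._offsets.tolist() == [0, 2, 7, 10]  # row boundaries
_example_njt._lengths  # often None for a freshly built contiguous layout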


@pytest.fixture(autouse=True, scope="function")
def empty_compiler_cache():
torch.compiler.reset()
yield


def _make_njt():
@@ -34,14 +48,29 @@ def _make_njt():
)


@pytest.fixture
def njt_td():
def _njt_td():
return TensorDict(
{str(i): {str(j): _make_njt() for j in range(32)} for i in range(32)},
# {str(i): {str(j): _make_njt() for j in range(32)} for i in range(32)},
{str(i): _make_njt() for i in range(128)},
device="cpu",
)


@pytest.fixture
def njt_td():
return _njt_td()


@pytest.fixture
def td():
njtd = _njt_td()
for k0, v0 in njtd.items():
njtd[k0] = NJT.from_njt(v0)
# for k1, v1 in v0.items():
# njtd[k0, k1] = NJT.from_njt(v1)
return njtd


@pytest.fixture
def default_device():
if torch.cuda.is_available():
@@ -52,22 +81,139 @@ def default_device():
pytest.skip("CUDA/MPS is not available")


@pytest.mark.parametrize("consolidated", [False, True])
@pytest.mark.parametrize(
"compile_mode,num_threads",
[
[False, None],
# [False, 4],
# [False, 16],
["default", None],
["reduce-overhead", None],
],
)
@pytest.mark.skipif(
TORCH_VERSION < version.parse("2.5.0"), reason="requires torch>=2.5"
)
class TestConsolidate:
def test_consolidate(self, benchmark, td, compile_mode, num_threads):
tensordict_logger.info(f"td size {td.bytes() / 1024 / 1024:.2f} Mb")

def consolidate(td, num_threads):
return td.consolidate(num_threads=num_threads)

if compile_mode:
consolidate = torch.compile(
consolidate, mode=compile_mode, dynamic=True, fullgraph=True
)

t0 = time.time()
consolidate(td, num_threads=num_threads)
elapsed = time.time() - t0
tensordict_logger.info(f"elapsed time first call: {elapsed:.2f} sec")

for _ in range(3):
consolidate(td, num_threads=num_threads)

benchmark(consolidate, td, num_threads)

def test_to_njt(self, benchmark, njt_td, compile_mode, num_threads):
tensordict_logger.info(f"njtd size {njt_td.bytes() / 1024 / 1024 :.2f} Mb")

def consolidate(td, num_threads):
return td.consolidate(num_threads=num_threads)

if compile_mode:
consolidate = torch.compile(consolidate, mode=compile_mode, dynamic=True)

for _ in range(3):
consolidate(njt_td, num_threads=num_threads)

benchmark(consolidate, njt_td, num_threads)


@pytest.mark.parametrize(
"consolidated,compile_mode,num_threads",
[
[False, False, None],
[True, False, None],
["within", False, None],
# [True, False, 4],
# [True, False, 16],
[True, "default", None],
],
)
@pytest.mark.skipif(
TORCH_VERSION < version.parse("2.5.1"), reason="requires torch>=2.5"
)
class TestTo:
def test_to(self, benchmark, consolidated, td, default_device):
if consolidated:
td = td.consolidate()
benchmark(lambda: td.to(default_device))
def test_to(
self, benchmark, consolidated, td, default_device, compile_mode, num_threads
):
tensordict_logger.info(f"td size {td.bytes() / 1024 / 1024:.2f} Mb")
pin_mem = default_device.type == "cuda"
if consolidated is True:
td = td.consolidate(pin_memory=pin_mem)

if consolidated == "within":

def to(td, num_threads):
return td.consolidate(pin_memory=pin_mem).to(
default_device, num_threads=num_threads
)

else:

def test_to_njt(self, benchmark, consolidated, njt_td, default_device):
if consolidated:
njt_td = njt_td.consolidate()
benchmark(lambda: njt_td.to(default_device))
def to(td, num_threads):
return td.to(default_device, num_threads=num_threads)

if compile_mode:
to = torch.compile(to, mode=compile_mode, dynamic=True)

for _ in range(3):
to(td, num_threads=num_threads)

benchmark(to, td, num_threads)

def test_to_njt(
self, benchmark, consolidated, njt_td, default_device, compile_mode, num_threads
):
tensordict_logger.info(f"njtd size {njt_td.bytes() / 1024 / 1024 :.2f} Mb")
pin_mem = default_device.type == "cuda"
if consolidated is True:
njt_td = njt_td.consolidate(pin_memory=pin_mem)

if consolidated == "within":

def to(td, num_threads):
return td.consolidate(pin_memory=pin_mem).to(
default_device, num_threads=num_threads
)

else:

def to(td, num_threads):
return td.to(default_device, num_threads=num_threads)

if compile_mode:
to = torch.compile(to, mode=compile_mode, dynamic=True)

for _ in range(3):
to(njt_td, num_threads=num_threads)

benchmark(to, njt_td, num_threads)


if __name__ == "__main__":
args, unknown = argparse.ArgumentParser().parse_known_args()
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
pytest.main(
[
__file__,
"--capture",
"no",
"--exitfirst",
"--benchmark-group-by",
"func",
"-vvv",
]
+ unknown
)
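Reading guide for the TestTo parametrization above (my gloss, not from the PR text): consolidated=False times a plain .to(device); True consolidates once before timing; "within" folds the consolidate() call into the timed function itself. Below is a minimal sketch of the consolidate-under-compile pattern these benchmarks exercise; it assumes torch>=2.5 plus this PR, and the tensor shapes are illustrative only.

import torch
from tensordict import TensorDict

td = TensorDict(
    {"a": torch.randn(16, 16), "b": {"c": torch.randn(16, 16)}},
    batch_size=[],
)

def consolidate(td, num_threads=None):
    # Packs every leaf into one contiguous storage, so a later .to(device)
    # is a single transfer instead of one copy per leaf.
    return td.consolidate(num_threads=num_threads)

compiled = torch.compile(consolidate, mode="default", dynamic=True)
for _ in range(3):  # warm-up calls, mirroring the benchmarks above
    compiled(td)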
6 changes: 6 additions & 0 deletions benchmarks/compile/compile_td_test.py
@@ -23,6 +23,12 @@ class MyTensorClass:
f: torch.Tensor


@pytest.fixture(autouse=True, scope="function")
def empty_compiler_cache():
torch._dynamo.reset_code_caches()
yield


# Functions
def add_one(td):
return td + 1
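Side note on the two reset helpers these fixtures use (my understanding; worth double-checking against the torch docs): torch.compiler.reset() discards all compilation state, so every test recompiles from a cold cache, while torch._dynamo.reset_code_caches() is a narrower, private helper that only invalidates compiled code attached to previously traced functions. A sketch of the two options:

import torch

def reset_everything():
    # Cold start: guards, caches, and dynamo state are all dropped.
    torch.compiler.reset()

def reset_compiled_code_only():
    # Private API: invalidates cached compiled code on seen functions,
    # keeping the rest of the compiler state warm (cheaper between tests).
    torch._dynamo.reset_code_caches()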
2 changes: 1 addition & 1 deletion tensordict/_reductions.py
@@ -138,7 +138,7 @@ def _make_td(cls, state):

def _reduce_td(data: TensorDict):
consolidated = getattr(data, "_consolidated", None)
if consolidated and consolidated["metadata"] is not None:
if isinstance(consolidated, dict):
storage = consolidated["storage"]
storge_metadata = consolidated["metadata"]
return (
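Why the one-line change helps under compile (hedged reading of the intent): getattr-then-truthiness-then-subscript makes dynamo guard on the dict's contents, while isinstance(consolidated, dict) is a single type check it can resolve statically. Note the semantics also widen slightly: a consolidated dict whose "metadata" entry is None now takes the consolidated path as well. Illustration with a stand-in value:

# Stand-in for data._consolidated: either None or the dict consolidate() builds.
consolidated = {"storage": object(), "metadata": {}}

# Before: truthiness plus a key lookup, and "metadata" had to be non-None.
old_hit = bool(consolidated) and consolidated["metadata"] is not None

# After: one compile-friendly type check.
new_hit = isinstance(consolidated, dict)

assert old_hit and new_hit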
7 changes: 4 additions & 3 deletions tensordict/_td.py
@@ -4210,7 +4210,7 @@ def _iter():
if self.leaves_only:
for key in self._keys():
target_class = self.tensordict.entry_class(key)
if _is_tensor_collection(target_class):
if self.is_leaf(target_class):
continue
yield key
else:
@@ -4239,9 +4239,10 @@ def _iter_helper(
# For lazy stacks
value = value[0]
cls = type(value)
is_leaf = self.is_leaf(cls)
if self.include_nested and not is_leaf:
is_tc = _is_tensor_collection(cls)
if self.include_nested and is_tc:
yield from self._iter_helper(value, prefix=full_key)
is_leaf = self.is_leaf(cls)
if not self.leaves_only or is_leaf:
yield full_key

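A gloss on the reordering in _iter_helper (my reading; names below are hypothetical): recursion into children is now gated on the structural check _is_tensor_collection, while the possibly user-supplied is_leaf predicate only decides whether a key is yielded. The two notions can therefore disagree, which is what lets a tensorclass such as NJT above behave as a leaf while still being traversable. Simplified sketch:

from tensordict.base import _is_tensor_collection  # private helper; path may vary

def iter_keys(node, *, include_nested, leaves_only, is_leaf, prefix=()):
    # Simplified, hypothetical rendering of _iter_helper after this PR.
    for key, value in node.items():
        full_key = prefix + (key,)
        cls = type(value)
        # The structural test decides what can be recursed into...
        if include_nested and _is_tensor_collection(cls):
            yield from iter_keys(
                value,
                include_nested=include_nested,
                leaves_only=leaves_only,
                is_leaf=is_leaf,
                prefix=full_key,
            )
        # ...while is_leaf decides what is yielded.
        if not leaves_only or is_leaf(cls):
            yield full_key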