[Performance] Make _to_consolidated compatible with compile
ghstack-source-id: bb7f342eee12ef61224c32aa0df6e93efaa1b117
Pull Request resolved: #1041
vmoens committed Oct 17, 2024
1 parent ee49fc7 commit 8a9de10
Showing 4 changed files with 217 additions and 22 deletions.
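
For context, a minimal usage sketch of what the change enables (illustrative only, not part of the commit; the example tensordict, device selection, and warm-up count are assumptions): consolidating a TensorDict packs all leaves into a single flat storage, and after this change the device transfer of that storage can run inside a torch.compile'd function rather than forcing an eager fallback.

import torch
from tensordict import TensorDict

device = "cuda" if torch.cuda.is_available() else "cpu"

td = TensorDict({"a": torch.randn(3), "b": {"c": torch.randn(3)}}, batch_size=[])
td = td.consolidate()  # packs all leaves into one flat storage plus metadata

@torch.compile
def move(td):
    # on a consolidated TensorDict, a cross-device .to() takes the
    # _to_consolidated path, which this commit makes compile-friendly
    return td.to(device)

for _ in range(3):  # warm-up so torch.compile traces and caches the graph
    move(td)
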
38 changes: 33 additions & 5 deletions benchmarks/common/h2d_test.py
@@ -14,6 +14,12 @@
TORCH_VERSION = version.parse(version.parse(torch.__version__).base_version)


@pytest.fixture(autouse=True, scope="function")
def empty_compiler_cache():
torch._dynamo.reset_code_caches()
yield


@pytest.fixture
def td():
return TensorDict(
@@ -52,20 +58,42 @@ def default_device():
pytest.skip("CUDA/MPS is not available")


@pytest.mark.parametrize("consolidated", [False, True])
@pytest.mark.parametrize(
"consolidated,compiled", [[False, False], [True, False], [True, True]]
)
@pytest.mark.skipif(
TORCH_VERSION < version.parse("2.5.0"), reason="requires torch>=2.5"
)
class TestTo:
def test_to(self, benchmark, consolidated, td, default_device):
def test_to(self, benchmark, consolidated, td, default_device, compiled):
if consolidated:
td = td.consolidate()
benchmark(lambda: td.to(default_device))

def test_to_njt(self, benchmark, consolidated, njt_td, default_device):
def to(td):
return td.to(default_device)

if compiled:
to = torch.compile(to)

for _ in range(3):
to(td)

benchmark(to, td)

def test_to_njt(self, benchmark, consolidated, njt_td, default_device, compiled):
if consolidated:
njt_td = njt_td.consolidate()
benchmark(lambda: njt_td.to(default_device))

def to(td):
return td.to(default_device)

if compiled:
to = torch.compile(to)

for _ in range(3):
to(njt_td)

benchmark(to, njt_td)


if __name__ == "__main__":
6 changes: 6 additions & 0 deletions benchmarks/compile/compile_td_test.py
@@ -23,6 +23,12 @@ class MyTensorClass:
f: torch.Tensor


@pytest.fixture(autouse=True, scope="function")
def empty_compiler_cache():
torch._dynamo.reset_code_caches()
yield


# Functions
def add_one(td):
return td + 1
187 changes: 170 additions & 17 deletions tensordict/base.py
@@ -72,7 +72,7 @@
_split_tensordict,
_td_fields,
_unravel_key_to_tuple,
_zip_strict,
_to_escape_compile,
_zip_strict,
cache,
convert_ellipsis_to_idx,
DeviceType,
@@ -3521,9 +3521,10 @@ def _reduce_vals_and_metadata(self, *, dtype=NO_DEFAULT, requires_metadata):

flat_size = []
start = 0
sorting_index = 0

def add_single_value(value, key, metadata_dict, dtype, shape, flat_size):
nonlocal start
nonlocal start, sorting_index
n = value.element_size() * value.numel()
if need_padding:
pad = n % 8
@@ -3541,7 +3542,10 @@ def add_single_value(value, key, metadata_dict, dtype, shape, flat_size):
start,
stop,
pad,
flat_size[-1],
sorting_index,
)
sorting_index = sorting_index + 1
start = stop

def assign(
@@ -10441,6 +10445,7 @@ def to(self, *args, **kwargs) -> T:
pin_memory=non_blocking_pin,
num_threads=num_threads,
non_blocking=non_blocking,
compilable=is_dynamo_compiling(),
)

if non_blocking is None:
@@ -10498,14 +10503,42 @@ def to_pinmem(tensor, _to=to):
self._sync_all()
return result

def _to_consolidated(self, *, device, pin_memory, num_threads, non_blocking):
def _to_consolidated(
self, *, device, pin_memory, num_threads, non_blocking, compilable
):
if num_threads is None:
# unspecified num_threads should mean 0
num_threads = 0

storage = self._consolidated["storage"]
if pin_memory:
storage = storage.pin_memory()
storage_cast = storage.to(device, non_blocking=True)

storage_cast = _to_escape_compile(storage, device, pin_memory)

if compilable:
result = self._to_consolidated_compile(
device=device, num_threads=num_threads, storage_cast=storage_cast
)
else:
result = self._to_consolidated_eager(
device=device, num_threads=num_threads, storage_cast=storage_cast
)

if non_blocking in (False, None):
if device.type == "cuda" and non_blocking is False:
# sending to CUDA force sync
cuda_device = device
elif storage.device.type == "cuda":
# sending from cuda: need sync unless intentionally not asked for
cuda_device = storage.device.type
else:
cuda_device = None
if cuda_device is not None:
torch.cuda.current_stream(cuda_device).synchronize()

return result

def _to_consolidated_eager(self, *, device, num_threads, storage_cast):

untyped_storage = storage_cast.untyped_storage()

def set_(x):
@@ -10574,18 +10607,138 @@ def copy_dict(d):
}

result._consolidated["metadata"] = copy_dict(self._consolidated["metadata"])
if non_blocking in (False, None):
if device.type == "cuda" and non_blocking is False:
# sending to CUDA force sync
cuda_device = device
elif storage.device.type == "cuda":
# sending from cuda: need sync unless intentionally not asked for
cuda_device = storage.device.type
else:
cuda_device = None
if cuda_device is not None:
torch.cuda.current_stream(cuda_device).synchronize()
return result

def _to_consolidated_compile(self, *, device, num_threads, storage_cast):

def get_tensors_length(metadata, lengths=None, pos=None, keys=None, prefix=()):
root = False
if lengths is None:
lengths = []
pos = []
keys = []
root = True
for k, v in metadata["leaves"].items():
lengths.append(v[-2])
pos.append(v[-1])
keys.append(prefix + (k,))
for k, d in metadata.items():
if "leaves" in d:
get_tensors_length(
d, lengths=lengths, pos=pos, keys=keys, prefix=prefix + (k,)
)
if root:
# l = torch.empty(len(lengths), dtype=torch.long)
# l[torch.as_tensor(pos)] = torch.as_tensor(lengths)
out0 = [
None,
] * len(pos)
out1 = [
None,
] * len(pos)
for p, l, k in zip(pos, lengths, keys):
out0[p] = k
out1[p] = l
return out0, out1

def split_storage(consolidated):
keys, splits = get_tensors_length(consolidated["metadata"])
return dict(zip(keys, consolidated["storage"].split(splits)))

if num_threads is None:
# unspecified num_threads should mean 0
num_threads = 0

_consolidated = {"storage": storage_cast}
if "metadata" in self._consolidated:
# faster than deepcopy
def copy_dict(d):
return {
k: v if not isinstance(v, dict) else copy_dict(v)
for k, v in d.items()
}

_consolidated["metadata"] = copy_dict(self._consolidated["metadata"])

slice_map = split_storage(_consolidated)

def view_as(src, dest):
return src.view(dest.dtype)[: dest.numel()].view(dest.shape)

def set_(name, x):
if not isinstance(name, tuple):
name = (name,)
if x.is_nested:
from torch._subclasses.fake_tensor import FakeTensor
from torch._subclasses.functional_tensor import FunctionalTensor
from torch.nested._internal.nested_tensor import (
_tensor_symint_registry,
NestedTensor,
)
from torch.nested._internal.ops import extract_kwargs

if x.layout != torch.jagged:
raise RuntimeError(
"to(device) with nested tensors that do not have a jagged layout is not implemented yet. "
"Please raise an issue on GitHub."
)
kwargs = extract_kwargs(x)
values = x._values
lengths = x._lengths
offsets = x._offsets
storage_offsets = slice_map[
(
*name[:-1],
"<NJT_OFFSETS>" + name[-1],
)
]
kwargs["offsets"] = view_as(storage_offsets, offsets)
if lengths is not None:
storage_lengths = slice_map[
(
*name[:-1],
"<NJT_LENGTHS>" + name[-1],
)
]
kwargs["lengths"] = view_as(storage_lengths, lengths)
ragged_source = lengths
else:
ragged_source = offsets
new_thing = kwargs.get("lengths", kwargs.get("offsets"))
if isinstance(new_thing, (FakeTensor, FunctionalTensor)):
from torch._subclasses.functional_tensor import (
mb_unwrap_functional_tensor,
)

# Temporary hack until we have the union find
tgt = mb_unwrap_functional_tensor(new_thing)
src = mb_unwrap_functional_tensor(ragged_source)
tgt.nested_int_memo = src.nested_int_memo
else:
_tensor_symint_registry[new_thing] = _tensor_symint_registry[
ragged_source
]

storage_values = slice_map[
(
*name[:-1],
"<NJT_VALUES>" + name[-1],
)
]
return NestedTensor(
view_as(storage_values, values),
**kwargs,
)
return view_as(slice_map[name], x)

result = self._fast_apply(
set_,
device=torch.device(device),
num_threads=num_threads,
named=True,
nested_keys=True,
)
result._consolidated = _consolidated
return result

def _sync_all(self):
8 changes: 8 additions & 0 deletions tensordict/utils.py
@@ -2694,3 +2694,11 @@ def _rebuild_njt_from_njt(x, values, offsets, lengths):
values,
**kwargs,
)


@torch.compiler.disable()
def _to_escape_compile(storage, device, pin_memory):
if pin_memory:
storage = storage.pin_memory()
storage_cast = storage.to(device, non_blocking=True)
return storage_cast
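
A standalone sketch of the pattern used by _to_escape_compile above (function and variable names here are illustrative, not from the commit): torch.compiler.disable keeps the pin_memory call and the host-to-device copy out of the Dynamo graph, so they run eagerly via a graph break while the rest of the function is still compiled.

import torch

@torch.compiler.disable()
def eager_copy(t, device, pin_memory=False):
    # runs eagerly even when called from compiled code (Dynamo graph break)
    if pin_memory:
        t = t.pin_memory()
    return t.to(device, non_blocking=True)

@torch.compile
def step(t):
    # the arithmetic is compiled; the copy helper is skipped by Dynamo
    return eager_copy(t * 2, "cpu") + 1

print(step(torch.ones(4)))
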
