
[Performance] Make _to_consolidated compatible with compile
ghstack-source-id: c0f61164cb144e5bbc750a697a3920dfce461dc9
Pull Request resolved: #1041
vmoens committed Oct 16, 2024
1 parent fe6db77 commit c730f7a
Showing 1 changed file with 141 additions and 17 deletions.
158 changes: 141 additions & 17 deletions tensordict/base.py
@@ -3521,9 +3521,10 @@ def _reduce_vals_and_metadata(self, *, dtype=NO_DEFAULT, requires_metadata):

flat_size = []
start = 0
sorting_index = 0

def add_single_value(value, key, metadata_dict, dtype, shape, flat_size):
nonlocal start
nonlocal start, sorting_index
n = value.element_size() * value.numel()
if need_padding:
pad = n % 8
@@ -3541,7 +3542,10 @@ def add_single_value(value, key, metadata_dict, dtype, shape, flat_size):
start,
stop,
pad,
flat_size[-1],
sorting_index,
)
sorting_index = sorting_index + 1
start = stop

def assign(
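
The two hunks above extend the per-leaf bookkeeping in _reduce_vals_and_metadata: besides the byte offsets start/stop and the padding, each leaf's metadata tuple now also records its flat byte size (flat_size[-1]) and a running sorting_index, so the layout order of the consolidated buffer can be recovered later without re-walking the structure. Below is a minimal sketch of that bookkeeping, not part of the commit; the layout_leaves helper and the exact padding arithmetic are illustrative assumptions.

    import torch

    def layout_leaves(leaves, need_padding=True):
        # leaves: {key: tensor}; returns {key: (start, stop, pad, flat_size, sorting_index)}
        metadata, flat_size, start, sorting_index = {}, [], 0, 0
        for key, value in leaves.items():
            n = value.element_size() * value.numel()
            pad = 0
            if need_padding:
                pad = n % 8
                if pad:
                    pad = 8 - pad  # assumed: pad each slice up to an 8-byte boundary
                n += pad
            flat_size.append(n)
            stop = start + n
            metadata[key] = (start, stop, pad, flat_size[-1], sorting_index)
            sorting_index += 1  # remembers the order in which leaves were laid out
            start = stop
        return metadata, flat_size

    meta, sizes = layout_leaves({"a": torch.ones(3), "b": torch.zeros(5, dtype=torch.float64)})
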
@@ -10395,6 +10399,7 @@ def to(self, *args, **kwargs) -> T:
pin_memory=non_blocking_pin,
num_threads=num_threads,
non_blocking=non_blocking,
compilable=is_dynamo_compiling(),
)

if non_blocking is None:
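
The to() hunk above adds a single argument: the consolidated path is told whether it is currently being traced by passing compilable=is_dynamo_compiling(). A hedged sketch of that dispatch, assuming a plain function rather than the TensorDictBase.to method (move_consolidated is a hypothetical name):

    import torch
    from torch.compiler import is_dynamo_compiling  # available in recent PyTorch releases

    def move_consolidated(td, device, num_threads=0, non_blocking=False):
        # Pick the graph-friendly implementation only while dynamo is tracing;
        # in eager mode the pre-existing implementation is used unchanged.
        return td._to_consolidated(
            device=torch.device(device),
            pin_memory=False,
            num_threads=num_threads,
            non_blocking=non_blocking,
            compilable=is_dynamo_compiling(),
        )
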
@@ -10452,14 +10457,42 @@ def to_pinmem(tensor, _to=to):
self._sync_all()
return result

def _to_consolidated(self, *, device, pin_memory, num_threads, non_blocking):
def _to_consolidated(self, *, device, pin_memory, num_threads, non_blocking, compilable):
if num_threads is None:
# unspecified num_threads should mean 0
num_threads = 0

storage = self._consolidated["storage"]
if pin_memory:
storage = storage.pin_memory()
storage_cast = storage.to(device, non_blocking=True)

@torch.compiler.disable()
def to(storage):
if pin_memory:
storage = storage.pin_memory()
storage_cast = storage.to(device, non_blocking=True)
return storage_cast
storage_cast = to(storage)

if compilable:
result = self._to_consolidated_compile(device=device, num_threads=num_threads, storage_cast=storage_cast)
else:
result = self._to_consolidated_eager(device=device, num_threads=num_threads, storage_cast=storage_cast)

if non_blocking in (False, None):
if device.type == "cuda" and non_blocking is False:
# sending to CUDA force sync
cuda_device = device
elif storage.device.type == "cuda":
# sending from cuda: need sync unless intentionally not asked for
cuda_device = storage.device.type
else:
cuda_device = None
if cuda_device is not None:
torch.cuda.current_stream(cuda_device).synchronize()

return result

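In the rewritten _to_consolidated above, the raw storage move (optional pin_memory() followed by storage.to(device, non_blocking=True)) is wrapped in a local function decorated with @torch.compiler.disable(), so dynamo treats the transfer as an opaque eager call instead of tracing through storage operations, and the post-transfer CUDA synchronization is shared by both the eager and the compile branches. A self-contained sketch of the same pattern on plain tensors, not the tensordict code itself:

    import torch

    @torch.compiler.disable()  # keep the transfer itself out of the traced graph
    def _move_buffer(buf, device, pin_memory=False):
        if pin_memory and buf.device.type == "cpu":
            buf = buf.pin_memory()
        return buf.to(device, non_blocking=True)

    @torch.compile
    def scale_on_device(buf, device):
        moved = _move_buffer(buf, device)  # dynamo graph-breaks around the eager helper
        return moved * 2

    out = scale_on_device(torch.randn(8), "cpu")  # works on CPU; use "cuda" if available
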
def _to_consolidated_eager(self, *, device, num_threads, storage_cast):

untyped_storage = storage_cast.untyped_storage()

def set_(x):
@@ -10528,20 +10561,111 @@ def copy_dict(d):
}

result._consolidated["metadata"] = copy_dict(self._consolidated["metadata"])
if non_blocking in (False, None):
if device.type == "cuda" and non_blocking is False:
# sending to CUDA force sync
cuda_device = device
elif storage.device.type == "cuda":
# sending from cuda: need sync unless intentionally not asked for
cuda_device = storage.device.type
else:
cuda_device = None
if cuda_device is not None:
torch.cuda.current_stream(cuda_device).synchronize()

return result

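_to_consolidated_eager keeps the pre-existing strategy (mostly unchanged and therefore collapsed in the diff): take storage_cast.untyped_storage() and rebind each leaf view onto it with the local set_ helper, then duplicate the consolidated metadata with the recursive copy_dict, which is cheaper than deepcopy because the leaf entries are immutable. The snippet below illustrates that general rebinding mechanism on plain tensors, under the assumption that the elided helper relies on Tensor.set_; it is not the exact tensordict code.

    import torch

    flat = torch.arange(12, dtype=torch.float32)   # stand-in for the consolidated buffer
    storage_cast = flat.to("cpu")                  # pretend this is the device-cast copy
    ustorage = storage_cast.untyped_storage()

    leaf = torch.empty(0, dtype=torch.float32)
    # offset counted in elements of `leaf`'s dtype; size/stride describe the leaf view
    leaf.set_(ustorage, 0, (3, 4), (4, 1))
    assert torch.equal(leaf, torch.arange(12, dtype=torch.float32).reshape(3, 4))
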
def _to_consolidated_compile(self, *, device, num_threads, storage_cast):

def get_tensors_length(metadata, lengths=None, pos=None, keys=None, prefix=()):
root = False
if lengths is None:
lengths = []
pos = []
keys = []
root = True
for k, v in metadata["leaves"].items():
lengths.append(v[-2])
pos.append(v[-1])
keys.append(prefix + (k,))
for k, d in metadata.items():
if "leaves" in d:
get_tensors_length(d, lengths=lengths, pos=pos, keys=keys, prefix=prefix + (k,))
if root:
# l = torch.empty(len(lengths), dtype=torch.long)
# l[torch.as_tensor(pos)] = torch.as_tensor(lengths)
out0 = [None, ] * len(pos)
out1 = [None, ] * len(pos)
for p, l, k in zip(pos, lengths, keys):
out0[p] = k
out1[p] = l
return out0, out1

def split_storage(consolidated):
keys, splits = get_tensors_length(consolidated["metadata"])
return dict(zip(keys, consolidated["storage"].split(splits)))

if num_threads is None:
# unspecified num_threads should mean 0
num_threads = 0

_consolidated = {"storage": storage_cast}
if "metadata" in self._consolidated:
# faster than deepcopy
def copy_dict(d):
return {
k: v if not isinstance(v, dict) else copy_dict(v)
for k, v in d.items()
}

_consolidated["metadata"] = copy_dict(self._consolidated["metadata"])

slice_map = split_storage(_consolidated)

def set_(name, x):
if not isinstance(name, tuple):
name = (name,)
if x.is_nested:
from torch._subclasses.fake_tensor import FakeTensor
from torch._subclasses.functional_tensor import FunctionalTensor
from torch.nested._internal.nested_tensor import (
_tensor_symint_registry,
NestedTensor,
)
from torch.nested._internal.ops import extract_kwargs

if x.layout != torch.jagged:
raise RuntimeError(
"to(device) with nested tensors that do not have a jagged layout is not implemented yet. "
"Please raise an issue on GitHub."
)
kwargs = extract_kwargs(x)
values = x._values
lengths = x._lengths
offsets = x._offsets
storage_offsets = slice_map[(*name[:-1], "<NJT_OFFSETS>"+name[-1],)]
kwargs["offsets"] = storage_offsets.view(offsets.dtype).view(offsets.shape)
if lengths is not None:
storage_lengths = slice_map[(*name[:-1], "<NJT_LENGTHS>"+name[-1],)]
kwargs["lengths"] = storage_lengths.view(lengths.dtype).view(lengths.shape)
ragged_source = lengths
else:
ragged_source = offsets
new_thing = kwargs.get("lengths", kwargs.get("offsets"))
if isinstance(new_thing, (FakeTensor, FunctionalTensor)):
from torch._subclasses.functional_tensor import (
mb_unwrap_functional_tensor,
)

# Temporary hack until we have the union find
tgt = mb_unwrap_functional_tensor(new_thing)
src = mb_unwrap_functional_tensor(ragged_source)
tgt.nested_int_memo = src.nested_int_memo
else:
_tensor_symint_registry[new_thing] = _tensor_symint_registry[
ragged_source
]

storage_values = slice_map[(*name[:-1], "<NJT_VALUES>"+name[-1],)]
return NestedTensor(
storage_values.view(values.dtype).view(values.shape),
**kwargs,
)
return slice_map[name].view(x.dtype).view(x.shape)

result = self._fast_apply(
set_, device=torch.device(device), num_threads=num_threads, named=True, nested_keys=True,
)
result._consolidated = _consolidated
return result
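
The compile-friendly path above avoids mutating existing tensors in place: get_tensors_length walks the nested metadata and returns each leaf's key and byte length in storage order (using the sorting index recorded earlier), split_storage cuts the single cast buffer into a slice_map of named byte slices with Tensor.split, and set_ then rebuilds every leaf (including the values/offsets/lengths components of jagged nested tensors) as slice.view(dtype).view(shape) inside _fast_apply. A toy reconstruction in the same spirit, assuming a flat metadata dict instead of tensordict's nested one; names and shapes are illustrative:

    import torch

    leaves = {"a": (torch.float32, (2, 3)), "b": (torch.int64, (4,))}
    lengths = [2 * 3 * 4, 4 * 8]                                 # byte length per leaf, in storage order
    storage_cast = torch.empty(sum(lengths), dtype=torch.uint8)  # stand-in for the cast buffer

    slice_map = dict(zip(leaves, storage_cast.split(lengths)))   # split_storage equivalent
    rebuilt = {
        key: slice_map[key].view(dtype).view(shape)              # reinterpret bytes, then reshape
        for key, (dtype, shape) in leaves.items()
    }
    assert rebuilt["a"].shape == (2, 3) and rebuilt["b"].dtype == torch.int64
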
def _sync_all(self):
if _has_cuda:
# TODO: dynamo doesn't like torch.cuda.is_initialized
