[Performance] Make _to_consolidated compatible with compile
ghstack-source-id: 17f1ce893b6f14b990b59703447301494b5f7585
Pull Request resolved: #1041
vmoens committed Oct 14, 2024
1 parent 7e45bcc commit b5736ab
Showing 1 changed file with 72 additions and 1 deletion: tensordict/base.py
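For context (not part of the commit; the keys, shapes and devices below are illustrative): _to_consolidated is the path taken when a consolidated TensorDict is moved to another device, and this change dispatches to a compile-friendly variant whenever the call is being traced by Dynamo. A minimal usage sketch, assuming a recent tensordict and PyTorch 2.x:

import torch
from tensordict import TensorDict

td = TensorDict(
    {"obs": torch.randn(4, 3), "reward": torch.randn(4, 1)}, batch_size=[4]
).consolidate()  # pack all leaves into a single contiguous storage

if torch.cuda.is_available():
    td_cuda = td.to("cuda")  # eager call: routed to _to_reconstruct

    # under torch.compile, is_dynamo_compiling() returns True and the
    # compile-friendly _to_reconstruct_compiled branch is taken instead
    td_cuda_compiled = torch.compile(lambda t: t.to("cuda"))(td)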
@@ -10456,17 +10456,88 @@ def _to_consolidated(self, *, device, pin_memory, num_threads, non_blocking):
        if pin_memory:
            storage = storage.pin_memory()
        storage_cast = storage.to(device, non_blocking=True)
        if is_dynamo_compiling():
            return self._to_reconstruct_compiled(
                storage, storage_cast, device, num_threads, non_blocking
            )
        return self._to_reconstruct(
            storage, storage_cast, device, num_threads, non_blocking
        )
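is_dynamo_compiling is not defined in this hunk; presumably it is imported elsewhere in the module along these lines (an assumption, shown only so the guard above reads as self-contained):

try:  # newer PyTorch exposes the guard under torch.compiler
    from torch.compiler import is_dynamo_compiling
except ImportError:  # older releases only have the private location
    from torch._dynamo import is_compiling as is_dynamo_compiling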

    def _to_reconstruct(self, storage, storage_cast, device, num_threads, non_blocking):
        untyped_storage = storage_cast.untyped_storage()

        def set_(x):
            if x.is_nested:
                if x.layout != torch.jagged:
                    raise RuntimeError(
                        "to(device) with nested tensors that do not have a jagged layout is not implemented yet. "
                        "Please raise an issue on GitHub."
                    )
                values = x._values
                lengths = x._lengths
                offsets = x._offsets
                return torch.nested.nested_tensor_from_jagged(
                    set_(values),
                    offsets=set_(offsets),
                    lengths=set_(lengths) if lengths is not None else None,
                )
            storage_offset = x.storage_offset()
            stride = x.stride()
            return x.new_empty((0,), device=device).set_(
                untyped_storage,
                size=x.shape,
                stride=stride,
                storage_offset=storage_offset,
            )
            # return torch.empty_like(x, device=device).set_(
            #     untyped_storage,
            #     size=x.shape,
            #     stride=stride,
            #     storage_offset=storage_offset,
            # )

        result = self._fast_apply(
            set_, device=torch.device(device), num_threads=num_threads
        )
        result._consolidated = {"storage": storage_cast}
        if "metadata" in self._consolidated:
            result._consolidated["metadata"] = deepcopy(self._consolidated["metadata"])
        if non_blocking in (False, None):
            if device.type == "cuda" and non_blocking is False:
                # sending to CUDA force sync
                cuda_device = device
            elif storage.device.type == "cuda":
                # sending from cuda: need sync unless intentionally not asked for
                cuda_device = storage.device.type
            else:
                cuda_device = None
            if cuda_device is not None:
                torch.cuda.current_stream(cuda_device).synchronize()

        return result
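To make the eager reconstruction above concrete: Tensor.set_ re-points a zero-sized tensor at a region of the shared untyped storage, so each leaf becomes a view of the single consolidated buffer rather than a copy. A self-contained sketch in plain PyTorch (independent of tensordict; buffer contents, shapes and offsets are made up):

import torch

# two tensors packed back to back in one flat buffer, mimicking a consolidated storage
flat = torch.arange(12, dtype=torch.float32)
untyped = flat.untyped_storage()

# rebuild views onto the shared storage: a (2, 3) tensor at element offset 0
# and a (6,) tensor starting at element offset 6
a = flat.new_empty((0,)).set_(untyped, size=(2, 3), stride=(3, 1), storage_offset=0)
b = flat.new_empty((0,)).set_(untyped, size=(6,), stride=(1,), storage_offset=6)

assert a.data_ptr() == flat.data_ptr()  # a aliases the start of the buffer
assert torch.equal(b, flat[6:])         # b is the tail of the buffer, no copy made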

    def _to_reconstruct_compiled(self, storage, storage_cast, device, num_threads, non_blocking):
        def set_(x):
            if x.is_nested:
                if x.layout != torch.jagged:
                    raise RuntimeError(
                        "to(device) with nested tensors that do not have a jagged layout is not implemented yet. "
                        "Please raise an issue on GitHub."
                    )
                values = x._values
                lengths = x._lengths
                offsets = x._offsets
                return torch._nested_view_from_jagged(
                    set_(values),
                    set_(offsets),
                    x,
                    lengths=set_(lengths) if lengths is not None else None,
                )
            storage_offset = x.storage_offset()
            stride = x.stride()
            index_slice = slice(storage_offset, storage_offset + x.numel(), stride[0])
            return storage_cast.view(x.dtype)[index_slice].view(x.shape)

        result = self._fast_apply(
            set_, device=torch.device(device), num_threads=num_threads
        )
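And the compile-friendly counterpart: the compiled branch avoids untyped_storage() and Tensor.set_, presumably because Dynamo does not trace through them, and instead rebuilds each leaf with ordinary slicing and view calls on the dtype-viewed cast storage. A simplified, self-contained sketch for contiguous leaves (the rebuild helper and the buffer below are illustrative, not tensordict API):

import math
import torch

# stand-in for storage_cast already moved to the target device
flat = torch.arange(12, dtype=torch.float32)

def rebuild(flat, shape, storage_offset):
    # slice out numel elements starting at storage_offset, then reshape:
    # plain view operations that torch.compile can trace
    numel = math.prod(shape)
    return flat[storage_offset : storage_offset + numel].view(shape)

a = rebuild(flat, (2, 3), 0)  # first leaf: elements 0..5
b = rebuild(flat, (6,), 6)    # second leaf: elements 6..11

assert torch.equal(a, flat[:6].view(2, 3))
assert a.untyped_storage().data_ptr() == flat.untyped_storage().data_ptr()  # still a view, no copy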
