[Performance] Make _to_consolidated compatible with compile #1041

Open · wants to merge 42 commits into base: gh/vmoens/30/base
Changes from 1 commit
73 changes: 72 additions & 1 deletion tensordict/base.py
@@ -10456,17 +10456,88 @@ def _to_consolidated(self, *, device, pin_memory, num_threads, non_blocking):
        if pin_memory:
            storage = storage.pin_memory()
        storage_cast = storage.to(device, non_blocking=True)
        if is_dynamo_compiling():
            return self._to_reconstruct_compiled(
                storage, storage_cast, device, num_threads, non_blocking
            )
        return self._to_reconstruct(
            storage, storage_cast, device, num_threads, non_blocking
        )

    def _to_reconstruct(self, storage, storage_cast, device, num_threads, non_blocking):
        untyped_storage = storage_cast.untyped_storage()

        def set_(x):
            if x.is_nested:
                if x.layout != torch.jagged:
                    raise RuntimeError(
                        "to(device) with nested tensors that do not have a jagged layout is not implemented yet. "
                        "Please raise an issue on GitHub."
                    )
                values = x._values
                lengths = x._lengths
                offsets = x._offsets
                return torch.nested.nested_tensor_from_jagged(
                    set_(values),
                    offsets=set_(offsets),
                    lengths=set_(lengths) if lengths is not None else None,
                )
            storage_offset = x.storage_offset()
            stride = x.stride()
            return x.new_empty((0,), device=device).set_(
                untyped_storage,
                size=x.shape,
                stride=stride,
                storage_offset=storage_offset,
            )
            # return torch.empty_like(x, device=device).set_(
            #     untyped_storage,
            #     size=x.shape,
            #     stride=stride,
            #     storage_offset=storage_offset,
            # )

        result = self._fast_apply(
            set_, device=torch.device(device), num_threads=num_threads
        )
        result._consolidated = {"storage": storage_cast}
        if "metadata" in self._consolidated:
            result._consolidated["metadata"] = deepcopy(self._consolidated["metadata"])
        if non_blocking in (False, None):
            if device.type == "cuda" and non_blocking is False:
                # sending to CUDA force sync
                cuda_device = device
            elif storage.device.type == "cuda":
                # sending from cuda: need sync unless intentionally not asked for
                cuda_device = storage.device.type
            else:
                cuda_device = None
            if cuda_device is not None:
                torch.cuda.current_stream(cuda_device).synchronize()

        return result

    def _to_reconstruct_compiled(self, storage, storage_cast, device, num_threads, non_blocking):
        def set_(x):
            if x.is_nested:
                if x.layout != torch.jagged:
                    raise RuntimeError(
                        "to(device) with nested tensors that do not have a jagged layout is not implemented yet. "
                        "Please raise an issue on GitHub."
                    )
                values = x._values
                lengths = x._lengths
                offsets = x._offsets
                return torch._nested_view_from_jagged(
                    set_(values),
                    set_(offsets),
                    x,
                    lengths=set_(lengths) if lengths is not None else None,
                )
            storage_offset = x.storage_offset()
            stride = x.stride()
            index_slice = slice(storage_offset, storage_offset + x.numel(), stride[0])
            return storage_cast.view(x.dtype)[index_slice].view(x.shape)

        result = self._fast_apply(
            set_, device=torch.device(device), num_threads=num_threads
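
For plain strided leaves, the two code paths above differ only in how they re-attach a view onto the consolidated buffer: the eager branch points a fresh tensor at the cast buffer's untyped storage via Tensor.set_, while the compiled branch rebuilds the view from a dtype view, a slice, and a reshape. Below is a minimal standalone sketch of both strategies on a contiguous float32 buffer; the tensor names and shapes are invented for illustration and are not taken from the PR.

import torch

src = torch.arange(12, dtype=torch.float32)   # stand-in for the consolidated flat storage
storage_cast = src.clone()                    # stand-in for the buffer after .to(device)
x = src.view(3, 4)                            # a leaf whose layout we want to rebuild

# Eager-style reconstruction: re-point an empty tensor at the untyped storage.
eager = x.new_empty((0,)).set_(
    storage_cast.untyped_storage(),
    size=x.shape,
    stride=x.stride(),
    storage_offset=x.storage_offset(),
)

# Compile-friendly reconstruction: slice the flat buffer and reshape it,
# avoiding untyped_storage()/set_, which the compiled branch steers clear of.
offset = x.storage_offset()
flat = storage_cast.view(x.dtype)[offset : offset + x.numel()]
compiled = flat.view(x.shape)

assert torch.equal(eager, compiled)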
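
When a leaf is a jagged nested tensor, both branches instead recurse into its values/offsets/lengths components and reassemble the result; the non-compiled branch does so through the public torch.nested.nested_tensor_from_jagged constructor. A small standalone illustration of that constructor follows (the values and offsets are made up, and PyTorch 2.3+ is assumed).

import torch

values = torch.arange(20, dtype=torch.float32).reshape(10, 2)  # flattened jagged values
offsets = torch.tensor([0, 3, 10])                             # row boundaries: lengths 3 and 7
nt = torch.nested.nested_tensor_from_jagged(values, offsets=offsets)
print(nt.shape[0], nt.layout)  # 2 torch.jagged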
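
Finally, the workflow the PR title targets, moving a consolidated TensorDict to another device inside a compiled function, would look roughly like the sketch below. This is a hypothetical usage example rather than code from the PR; it assumes a tensordict version that provides TensorDict.consolidate() and an available CUDA device.

import torch
from tensordict import TensorDict

td = TensorDict(
    {"a": torch.randn(3, 4), "nested": {"b": torch.zeros(3)}},
    batch_size=[3],
).consolidate()  # pack every leaf into a single contiguous storage


@torch.compile
def move(td):
    # with this PR, this should dispatch to _to_consolidated and, while
    # compiling, to the _to_reconstruct_compiled path shown above
    return td.to("cuda")


if torch.cuda.is_available():
    td_cuda = move(td)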