Commit 90010ab

Update

[ghstack-poisoned]
vmoens committed Dec 15, 2024
1 parent 4af5c62 commit 90010ab
Showing 6 changed files with 46 additions and 49 deletions.
16 changes: 8 additions & 8 deletions tensordict/_td.py
@@ -98,9 +98,9 @@

_has_funcdim = False
try:
- from torch.compiler import is_dynamo_compiling
+ from torch.compiler import is_compiling
except ImportError: # torch 2.0
- from torch._dynamo import is_compiling as is_dynamo_compiling
+ from torch._dynamo import is_compiling

try:
from torch.nn.parameter import Buffer
@@ -251,7 +251,7 @@ def __init__(

self._tensordict = _StringOnlyDict()

- # if names and is_dynamo_compiling():
+ # if names and is_compiling():
# graph_break()
has_device = device is not None
sub_non_blocking = False
@@ -284,7 +284,7 @@ def __init__(
)
self._batch_size = self._parse_batch_size(source, batch_size)
# TODO: this breaks when stacking tensorclasses with dynamo
- if not is_dynamo_compiling():
+ if not is_compiling():
self.names = names

for key, value in source.items():
@@ -313,7 +313,7 @@ def _new_unsafe(
nested: bool = True,
**kwargs: dict[str, Any] | None,
) -> TensorDict:
- if is_dynamo_compiling():
+ if is_compiling():
return TensorDict(
source,
batch_size=batch_size,
@@ -473,7 +473,7 @@ def _to_module(
is_dynamo: bool | None = None,
):
if is_dynamo is None:
- is_dynamo = is_dynamo_compiling()
+ is_dynamo = is_compiling()
if is_dynamo:
_check_inbuild()

@@ -2264,7 +2264,7 @@ def _parse_batch_size(
) -> torch.Size:
ERR = "batch size was not specified when creating the TensorDict instance and it could not be retrieved from source."

- if is_dynamo_compiling():
+ if is_compiling():
if isinstance(batch_size, torch.Size):
return batch_size
elif isinstance(batch_size, tuple):
@@ -2316,7 +2316,7 @@ def names(self):

@names.setter
def names(self, value):
- if is_dynamo_compiling():
+ if is_compiling():
if value is not None:
graph_break()
else:
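A note on the tensordict/_td.py hunks above: the local alias is renamed from is_dynamo_compiling to is_compiling, matching the name of the torch.compiler helper it imports, and the same guard keeps name-setting and batch-size parsing out of traced code. The following is a minimal standalone sketch of that import shim and guard, not part of the commit; maybe_set_names is a hypothetical helper, not tensordict API.

try:
    # Public helper in recent PyTorch releases.
    from torch.compiler import is_compiling
except ImportError:  # torch 2.0: fall back to the private dynamo helper
    from torch._dynamo import is_compiling


def maybe_set_names(td, names):
    # Mirrors the guarded branch in TensorDict.__init__: setting dimension names
    # is skipped while tracing, since it currently breaks when stacking
    # tensorclasses with dynamo (per the TODO in the diff above).
    if not is_compiling():
        td.names = names
    return td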
10 changes: 5 additions & 5 deletions tensordict/_torch_func.py
@@ -39,9 +39,9 @@
)

try:
- from torch.compiler import is_dynamo_compiling
+ from torch.compiler import is_compiling
except ImportError: # torch 2.0
- from torch._dynamo import is_compiling as is_dynamo_compiling
+ from torch._dynamo import is_compiling

TD_HANDLED_FUNCTIONS: dict[Callable, Callable] = {}
LAZY_TD_HANDLED_FUNCTIONS: dict[Callable, Callable] = {}
@@ -301,7 +301,7 @@ def _cat(
out = {}
for key in keys:
items = [td._get_str(key, NO_DEFAULT) for td in list_of_tensordicts]
- if not is_dynamo_compiling():
+ if not is_compiling():
with _ErrorInteceptor(
key, "Attempted to concatenate tensors on different devices at key"
):
@@ -335,7 +335,7 @@ def _cat(
_ErrorInteceptor(
key, "Attempted to concatenate tensors on different devices at key"
)
- if not is_dynamo_compiling()
+ if not is_compiling()
else contextlib.nullcontext()
):
if isinstance(out, TensorDict):
@@ -592,7 +592,7 @@ def stack_fn(key, values, is_not_init, is_tensor):
_ErrorInteceptor(
key, "Attempted to stack tensors on different devices at key"
)
- if not is_dynamo_compiling()
+ if not is_compiling()
else contextlib.nullcontext()
):
return _stack(values, dim, maybe_dense_stack=maybe_dense_stack)
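For the tensordict/_torch_func.py hunks, the error-interceptor context manager is only active in eager mode; while compiling, contextlib.nullcontext() is used instead so the error-annotation logic stays out of the traced graph. Below is a standalone sketch of that conditional-context pattern with a simplified stand-in for _ErrorInteceptor; it is illustrative, not the library's implementation.

import contextlib

import torch

try:
    from torch.compiler import is_compiling
except ImportError:  # torch 2.0
    from torch._dynamo import is_compiling


@contextlib.contextmanager
def error_interceptor(key, prefix):
    # Simplified stand-in for tensordict's _ErrorInteceptor: annotate failures
    # with the offending key so the user knows which entry broke a cat/stack.
    try:
        yield
    except RuntimeError as err:
        raise RuntimeError(f"{prefix} {key!r}") from err


def cat_values(key, tensors, dim=0):
    # The interceptor is active in eager mode only; under torch.compile it is
    # swapped for a no-op context, mirroring the pattern in _cat/_stack above.
    with (
        error_interceptor(key, "Attempted to concatenate tensors on different devices at key")
        if not is_compiling()
        else contextlib.nullcontext()
    ):
        return torch.cat(tensors, dim)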
31 changes: 14 additions & 17 deletions tensordict/base.py
@@ -103,9 +103,9 @@
from torch.utils._pytree import tree_map

try:
- from torch.compiler import is_dynamo_compiling
+ from torch.compiler import is_compiling
except ImportError: # torch 2.0
- from torch._dynamo import is_compiling as is_dynamo_compiling
+ from torch._dynamo import is_compiling

try:
from torch import _foreach_copy_
@@ -5247,7 +5247,7 @@ def _view_and_pad(tensor):
if v.device != storage.device:
v = v.to(storage.device, non_blocking=non_blocking)
stride = v.stride()
- if is_dynamo_compiling():
+ if is_compiling():
if not v.is_contiguous():
v = v.clone(memory_format=torch.contiguous_format)
elif (stride and stride[-1] != 1) or v.storage_offset():
@@ -6963,7 +6963,7 @@ def _values_list(
is_leaf=is_leaf,
collapse=collapse,
)
- if is_dynamo_compiling():
+ if is_compiling():
key_to_index = {key: i for i, key in enumerate(keys)}
return [vals[key_to_index[key]] for key in sorting_keys]
else:
@@ -6994,7 +6994,7 @@ def _items_list(
return list(keys), list(vals)
if default is None:
# TODO: check that lists are identical
- if is_dynamo_compiling():
+ if is_compiling():
key_to_index = {key: i for i, key in enumerate(keys)}
new_vals = [vals[key_to_index[key]] for key in sorting_keys]
if len(new_vals) < len(vals):
@@ -7015,12 +7015,9 @@
] # intersection does not keep the sorting
else:
new_keys = list(set(sorting_keys).union(keys))
- if is_dynamo_compiling():
-     ...
- else:
-     source = dict(zip(keys, vals))
-     vals = [source.get(key, default) for key in new_keys]
- return new_keys, vals
+ source = dict(zip(keys, vals))
+ vals = [source.get(key, default) for key in new_keys]
+ return new_keys, vals

def _grad(self):
# We can't cache this because zero_grad can be called outside (eg from optimizer) and we want the tensors
@@ -11931,7 +11928,7 @@ def unflatten_keys(self, separator: str = ".", inplace: bool = False) -> T:
result.lock_()
return result
else:
- if not is_dynamo_compiling():
+ if not is_compiling():
key_list = list(self.keys())
else:
key_list = [k for k in self.keys()] # noqa
@@ -12196,7 +12193,7 @@ def lock_(self) -> T:
"""
if self.is_locked:
return self
- is_compiling = is_dynamo_compiling()
- if is_compiling:
+ is_comp = is_compiling()
+ if is_comp:
_lock_warn()
- self._propagate_lock(is_compiling=is_compiling)
+ self._propagate_lock(is_compiling=is_comp)
return self

@erase_cache
@@ -12611,7 +12608,7 @@ def copy_dict(d):
def _sync_all(self):
if _has_cuda:
# TODO: dynamo doesn't like torch.cuda.is_initialized
- if not is_dynamo_compiling() and torch.cuda.is_initialized():
+ if not is_compiling() and torch.cuda.is_initialized():
torch.cuda.synchronize()
elif _has_mps:
mps = getattr(torch, "mps", None)
@@ -12799,7 +12796,7 @@ def _register_tensor_class(cls):


def _is_tensor_collection(datatype: type) -> bool:
- is_dynamo = is_dynamo_compiling()
+ is_dynamo = is_compiling()
out = None
if not is_dynamo:
out = _TENSOR_COLLECTION_MEMO.get(datatype)
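For tensordict/base.py, besides the rename, two details stand out: _items_list no longer special-cases dynamo when a default is provided (the dict(zip(...)) path is now used unconditionally), and lock_ renames its local variable to is_comp so it no longer shadows the imported is_compiling. The sketch below illustrates the key-reordering idiom used by _values_list/_items_list; reorder is a hypothetical helper, not tensordict API.

try:
    from torch.compiler import is_compiling
except ImportError:  # torch 2.0
    from torch._dynamo import is_compiling


def reorder(keys, vals, sorting_keys):
    # Return vals re-ordered to follow sorting_keys.
    if is_compiling():
        # Index map: dynamo traces the comprehension over indices cleanly.
        key_to_index = {key: i for i, key in enumerate(keys)}
        return [vals[key_to_index[key]] for key in sorting_keys]
    # Eager mode: a plain dict lookup is simplest.
    source = dict(zip(keys, vals))
    return [source[key] for key in sorting_keys]


# Example: reorder(["a", "b", "c"], [1, 2, 3], ["c", "a"]) -> [3, 1]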
8 changes: 4 additions & 4 deletions tensordict/nn/common.py
@@ -27,9 +27,9 @@
from torch import nn, Tensor

try:
- from torch.compiler import is_dynamo_compiling
+ from torch.compiler import is_compiling
except ImportError: # torch 2.0
- from torch._dynamo import is_compiling as is_dynamo_compiling
+ from torch._dynamo import is_compiling

try:
from functorch import FunctionalModule, FunctionalModuleWithBuffers
@@ -1153,7 +1153,7 @@ def __repr__(self) -> str:
return f"{type(self).__name__}(\n{fields})"

def __getattr__(self, name: str) -> Any:
- if not is_dynamo_compiling():
+ if not is_compiling():
__dict__ = self.__dict__
_parameters = __dict__.get("_parameters")
if _parameters:
@@ -1230,7 +1230,7 @@ def __init__(self, td_module: TensorDictModuleBase) -> None:
self.register_forward_hook(self.td_module._forward_hooks[pre_hook])

def __getattr__(self, name: str) -> Any:
- if not is_dynamo_compiling():
+ if not is_compiling():
__dict__ = self.__dict__
_parameters = __dict__.get("_parameters")
if _parameters:
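The tensordict/nn/common.py change touches the custom __getattr__ hooks: direct __dict__ introspection only happens in eager mode, and attribute resolution falls through to the parent class while compiling. Below is a minimal sketch of that guard with an illustrative module class, not the actual TensorDictModule code.

from torch import nn

try:
    from torch.compiler import is_compiling
except ImportError:  # torch 2.0
    from torch._dynamo import is_compiling


class GuardedAttrModule(nn.Module):
    # Illustrative only: resolve attributes through the __dict__ buckets in
    # eager mode, but skip that introspection under torch.compile and defer
    # to nn.Module.__getattr__.
    def __getattr__(self, name):
        if not is_compiling():
            _parameters = self.__dict__.get("_parameters")
            if _parameters is not None and name in _parameters:
                return _parameters[name]
        return super().__getattr__(name)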
10 changes: 5 additions & 5 deletions tensordict/nn/utils.py
@@ -18,9 +18,9 @@
from torch.utils._contextlib import _DecoratorContextManager

try:
- from torch.compiler import is_dynamo_compiling
+ from torch.compiler import is_compiling
except ImportError: # torch 2.0
- from torch._dynamo import is_compiling as is_dynamo_compiling
+ from torch._dynamo import is_compiling


_dispatch_tdnn_modules = _ContextManager(
@@ -300,7 +300,7 @@ def wrapper(_self, tensordict, *args: Any, **kwargs: Any) -> Any:
return super().__call__(wrapper)

def __enter__(self) -> None:
- if self.mode and is_dynamo_compiling():
+ if self.mode and is_compiling():
raise RuntimeError("skip_existing is not compatible with TorchDynamo.")
self.prev = _skip_existing.get_mode()
if self.mode is not None:
@@ -338,7 +338,7 @@ def __call__(self, func: Callable):

@functools.wraps(func)
def wrapper(_self, tensordict, *args: Any, **kwargs: Any) -> Any:
- if skip_existing() and is_dynamo_compiling():
+ if skip_existing() and is_compiling():
raise RuntimeError(
"skip_existing is not compatible with torch.compile."
)
@@ -351,7 +351,7 @@ def wrapper(_self, tensordict, *args: Any, **kwargs: Any) -> Any:
and not any(key in out_keys for key in in_keys)
):
return tensordict
- if is_dynamo_compiling():
+ if is_compiling():
return func(_self, tensordict, *args, **kwargs)
self.prev = _skip_existing.get_mode()
try:
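In tensordict/nn/utils.py, the skip_existing machinery is stateful and explicitly unsupported under compilation: the guards raise if both are active, and the wrapper calls the function directly while tracing instead of saving and restoring the mode. A standalone sketch of that save/restore wrapper follows, using hypothetical get_mode/set_mode callables rather than the real _skip_existing internals.

import functools

try:
    from torch.compiler import is_compiling
except ImportError:  # torch 2.0
    from torch._dynamo import is_compiling


def with_mode_bookkeeping(func, get_mode, set_mode):
    # get_mode/set_mode are hypothetical stand-ins for the module's mode state.
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        if is_compiling():
            # Skip the eager-only global-state bookkeeping while tracing.
            return func(*args, **kwargs)
        prev = get_mode()
        try:
            return func(*args, **kwargs)
        finally:
            set_mode(prev)

    return wrapper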
20 changes: 10 additions & 10 deletions tensordict/tensorclass.py
@@ -64,9 +64,9 @@
from torch.utils._pytree import tree_map

try:
- from torch.compiler import is_dynamo_compiling
+ from torch.compiler import is_compiling
except ImportError: # torch 2.0
- from torch._dynamo import is_compiling as is_dynamo_compiling
+ from torch._dynamo import is_compiling


def _identity(cls):
@@ -890,7 +890,7 @@ def wrapper(
if lock is None:
lock = frozen

- if not is_dynamo_compiling():
+ if not is_compiling():
# zip not supported by dynamo
for value, key in zip(args, self.__dataclass_fields__):
if key in kwargs:
@@ -904,7 +904,7 @@

if batch_size is None:
batch_size = torch.Size([])
- if not is_dynamo_compiling():
+ if not is_compiling():
for key, field in type(self).__dataclass_fields__.items():
if field.default_factory is not dataclasses.MISSING:
default = field.default_factory()
@@ -1072,7 +1072,7 @@ def _from_tensordict(cls, tensordict, non_tensordict=None, safe=True): # noqa:
# tensordict = tensordict.copy()
tensor_keys = tensordict.keys()
# TODO: compile doesn't like set() over an arbitrary object
- if is_dynamo_compiling():
+ if is_compiling():
tensor_keys = {k for k in tensor_keys} # noqa: C416
exp_keys = {k for k in cls.__expected_keys__} # noqa: C416
if non_tensordict is not None:
@@ -1112,7 +1112,7 @@ def _from_tensordict(cls, tensordict, non_tensordict=None, safe=True): # noqa:
for key in to_add:
non_tensordict[key] = None

- if not is_dynamo_compiling():
+ if not is_compiling():
# bypass initialisation. this means we don't incur any overhead creating an
# empty tensordict and writing values to it. we can skip this because we already
# have a tensordict to use as the underlying tensordict
@@ -1313,7 +1313,7 @@ def wrapper(self, key: str, value: Any) -> None: # noqa: D417
value (any): the value to set for the attribute
"""
- if not is_dynamo_compiling():
+ if not is_compiling():
__dict__ = self.__dict__
if (
"_tensordict" not in __dict__
@@ -1348,7 +1348,7 @@ def deliver_result(self, result, kwargs):
if result is None:
return
if isinstance(result, TensorDictBase) and kwargs.get("out") is not result:
- if not is_dynamo_compiling():
+ if not is_compiling():
non_tensordict = super(type(self), self).__getattribute__(
"_non_tensordict"
)
@@ -1362,7 +1362,7 @@ def deliver_result(self, result, kwargs):
return result

def wrapped_func(self, *args, **kwargs):
- if not is_dynamo_compiling():
+ if not is_compiling():
td = super(type(self), self).__getattribute__("_tensordict")
else:
td = self._tensordict
@@ -1409,7 +1409,7 @@ def wrapped_func(*args, **kwargs):
return type(self)._from_tensordict(res, dict(self._non_tensordict))
return res

- if not is_dynamo_compiling():
+ if not is_compiling():
wrapped_func = functools.wraps(func)(wrapped_func)

return wrapped_func
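Finally, tensordict/tensorclass.py applies the same rename to the tensorclass machinery, where several compile-aware workarounds are gated: positional arguments are mapped onto dataclass fields with zip only in eager mode, key views are turned into sets via a comprehension while tracing (compile does not like set() over arbitrary objects), and functools.wraps is applied only outside tracing. Below is a standalone sketch of two of these gates, using an illustrative dataclass rather than a real tensorclass.

import dataclasses
import functools

try:
    from torch.compiler import is_compiling
except ImportError:  # torch 2.0
    from torch._dynamo import is_compiling


@dataclasses.dataclass
class Point:  # illustrative stand-in for a @tensorclass
    x: float = 0.0
    y: float = 0.0


def bind_positionals(cls, args, kwargs):
    # Mirrors the eager-only gate in the diff above: the zip over dataclass
    # fields runs only outside torch.compile.
    if not is_compiling():
        for value, key in zip(args, cls.__dataclass_fields__):
            if key in kwargs:
                raise ValueError(f"got multiple values for argument {key!r}")
            kwargs[key] = value
    return kwargs


def maybe_wraps(original, wrapper):
    # functools.wraps copies metadata eagerly; it is skipped while compiling.
    return functools.wraps(original)(wrapper) if not is_compiling() else wrapper


# Example: bind_positionals(Point, (1.0,), {"y": 2.0}) -> {"y": 2.0, "x": 1.0}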
