diff --git a/tensordict/nn/params.py b/tensordict/nn/params.py index a9435669b..d449018a8 100644 --- a/tensordict/nn/params.py +++ b/tensordict/nn/params.py @@ -117,9 +117,8 @@ def _maybe_make_param(tensor): def _maybe_make_param_or_buffer(tensor): - if ( - isinstance(tensor, (Tensor, ftdim.Tensor)) - and not isinstance(tensor, (nn.Parameter, Buffer)) + if isinstance(tensor, (Tensor, ftdim.Tensor)) and not isinstance( + tensor, (nn.Parameter, Buffer) ): if not tensor.requires_grad and not is_batchedtensor(tensor): # convert all non-parameters to buffers diff --git a/tensordict/utils.py b/tensordict/utils.py index ab4039944..0e370856f 100644 --- a/tensordict/utils.py +++ b/tensordict/utils.py @@ -70,10 +70,7 @@ try: from torch.compiler import assume_constant_result, is_compiling except ImportError: # torch 2.0 - from torch._dynamo import ( - assume_constant_result, - is_compiling, - ) + from torch._dynamo import assume_constant_result, is_compiling if TYPE_CHECKING: from tensordict.tensordict import TensorDictBase @@ -2825,7 +2822,8 @@ def _is_dataclass(obj): if isinstance(obj, type) and not isinstance(obj, GenericAlias) else type(obj) ) - return hasattr(cls, _FIELDS) + # return hasattr(cls, _FIELDS) + return getattr(cls, _FIELDS, None) is not None def _is_list_tensor_compatible(t) -> Tuple[bool, tuple | None, type | None]: