diff --git a/tensordict/nn/params.py b/tensordict/nn/params.py
index d449018a8..54017ed2b 100644
--- a/tensordict/nn/params.py
+++ b/tensordict/nn/params.py
@@ -107,12 +107,28 @@ def _get_args_dict(func, args, kwargs):
 
 
 def _maybe_make_param(tensor):
-    if (
-        isinstance(tensor, (Tensor, ftdim.Tensor))
-        and not isinstance(tensor, nn.Parameter)
-        and tensor.dtype in (torch.float, torch.double, torch.half)
+    """Wrap a plain tensor in ``nn.Parameter`` or in a buffer class.
+
+    Objects that are not a ``Tensor``/``ftdim.Tensor``, or that are already an
+    ``nn.Parameter``, ``Buffer`` or ``BufferLegacy``, are returned unchanged.
+    Tensors of dtype float, double or half become an ``nn.Parameter``. Any
+    other dtype becomes a ``Buffer``, unless ``is_batchedtensor(tensor)`` is
+    true, in which case it becomes a ``BufferLegacy``.
+    """
+    if isinstance(tensor, (Tensor, ftdim.Tensor)) and not isinstance(
+        tensor, (nn.Parameter, Buffer, BufferLegacy)
     ):
-        tensor = nn.Parameter(tensor)
+        if tensor.dtype in (torch.float, torch.double, torch.half):
+            tensor = nn.Parameter(tensor)
+        elif not is_batchedtensor(tensor):
+            # Dtypes outside (float, double, half) are never promoted to
+            # parameters; keep them as buffers instead.
+            tensor = Buffer(tensor)
+        else:
+            # Batched tensors go through BufferLegacy rather than Buffer.
+            # NOTE(review): presumably to keep the grad_fn of the input, e.g.
+            # param.expand(10) should still point to the original param - confirm.
+            tensor = BufferLegacy(tensor)
     return tensor
 
 
diff --git a/test/test_nn.py b/test/test_nn.py
index 5c32ac6e3..ba9faf93d 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -1896,6 +1896,57 @@ def test_td_params(self):
         assert (m.params == params).all()
         assert (params == m.params).all()
 
+    def test_constructors(self):
+        class MyModule(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.register_parameter(
+                    "param", nn.Parameter(torch.randn(3, requires_grad=True))
+                )
+                self.register_buffer("buf", torch.randn(3))
+                self.register_buffer("buf_int", torch.randint(3, ()))
+
+        # Default: float buffers are promoted to parameters, int ones stay buffers.
+        td = TensorDict.from_module(MyModule())
+        assert not isinstance(td, TensorDictParams)
+        td = TensorDictParams(td)
+        assert isinstance(td, TensorDictParams)
+        assert isinstance(td["param"], nn.Parameter)
+        assert isinstance(td["buf"], nn.Parameter)
+        assert isinstance(td["buf_int"], Buffer)
+        # no_convert=True: the float buffer is not promoted to a parameter.
+        td = TensorDict.from_module(MyModule())
+        assert not isinstance(td, TensorDictParams)
+        td = TensorDictParams(td, no_convert=True)
+        assert isinstance(td, TensorDictParams)
+        assert isinstance(td["param"], nn.Parameter)
+        assert isinstance(td["buf"], Buffer)
+        assert isinstance(td["buf_int"], Buffer)
+
+        # as_module=True yields a TensorDictParams directly, buffers preserved.
+        td = TensorDict.from_module(MyModule(), as_module=True)
+        assert isinstance(td, TensorDictParams)
+        assert isinstance(td["param"], nn.Parameter)
+        assert isinstance(td["buf"], Buffer)
+        assert isinstance(td["buf_int"], Buffer)
+
+        # Keyword construction: int values become buffers, floats parameters.
+        tdparams = TensorDictParams(a=0, b=1.0)
+        assert isinstance(tdparams["a"], Buffer)
+        assert isinstance(tdparams["b"], nn.Parameter)
+
+        # Same expectations when constructing from a dict.
+        tdparams = TensorDictParams({"a": 0, "b": 1.0})
+        assert isinstance(tdparams["a"], Buffer)
+        assert isinstance(tdparams["b"], nn.Parameter)
+        # The copy is expected to hold the very same leaf objects (identity).
+        tdparams_copy = tdparams.copy()
+
+        def assert_is_identical(a, b):
+            assert a is b
+
+        tdparams.apply(assert_is_identical, tdparams_copy, filter_empty=True)
+
     def test_td_params_cast(self):
         params = self._get_params()
         p = TensorDictParams(params)