Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Nov 25, 2024
1 parent 8fbf611 commit 2bc10df
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 5 deletions.
16 changes: 11 additions & 5 deletions tensordict/nn/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,12 +107,18 @@ def _get_args_dict(func, args, kwargs):


def _maybe_make_param(tensor):
if (
isinstance(tensor, (Tensor, ftdim.Tensor))
and not isinstance(tensor, nn.Parameter)
and tensor.dtype in (torch.float, torch.double, torch.half)
if isinstance(tensor, (Tensor, ftdim.Tensor)) and not isinstance(
tensor, (nn.Parameter, Buffer, BufferLegacy)
):
tensor = nn.Parameter(tensor)
if tensor.dtype in (torch.float, torch.double, torch.half):
tensor = nn.Parameter(tensor)
elif not is_batchedtensor(tensor):
# convert all non-parameters to buffers
# dataptr = tensor.data.data_ptr()
tensor = Buffer(tensor)
else:
# We want to keep the grad_fn of tensors, e.g. param.expand(10) should point to the original param
tensor = BufferLegacy(tensor)
return tensor


Expand Down
45 changes: 45 additions & 0 deletions test/test_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -1896,6 +1896,51 @@ def test_td_params(self):
assert (m.params == params).all()
assert (params == m.params).all()

def test_constructors(self):
class MyModule(nn.Module):
def __init__(self):
super().__init__()
self.register_parameter(
"param", nn.Parameter(torch.randn(3, requires_grad=True))
)
self.register_buffer("buf", torch.randn(3))
self.register_buffer("buf_int", torch.randint(3, ()))

td = TensorDict.from_module(MyModule())
assert not isinstance(td, TensorDictParams)
td = TensorDictParams(td)
assert isinstance(td, TensorDictParams)
assert isinstance(td["param"], nn.Parameter)
assert isinstance(td["buf"], nn.Parameter)
assert isinstance(td["buf_int"], Buffer)
td = TensorDict.from_module(MyModule())
assert not isinstance(td, TensorDictParams)
td = TensorDictParams(td, no_convert=True)
assert isinstance(td, TensorDictParams)
assert isinstance(td["param"], nn.Parameter)
assert isinstance(td["buf"], Buffer)
assert isinstance(td["buf_int"], Buffer)

td = TensorDict.from_module(MyModule(), as_module=True)
assert isinstance(td, TensorDictParams)
assert isinstance(td["param"], nn.Parameter)
assert isinstance(td["buf"], Buffer)
assert isinstance(td["buf_int"], Buffer)

tdparams = TensorDictParams(a=0, b=1.0)
assert isinstance(tdparams["a"], Buffer)
assert isinstance(tdparams["b"], nn.Parameter)

tdparams = TensorDictParams({"a": 0, "b": 1.0})
assert isinstance(tdparams["a"], Buffer)
assert isinstance(tdparams["b"], nn.Parameter)
tdparams_copy = tdparams.copy()

def assert_is_identical(a, b):
assert a is b

tdparams.apply(assert_is_identical, tdparams_copy, filter_empty=True)

def test_td_params_cast(self):
params = self._get_params()
p = TensorDictParams(params)
Expand Down

0 comments on commit 2bc10df

Please sign in to comment.