Skip to content

Commit

Permalink
Update (base update)
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Dec 17, 2024
1 parent 2aea3dd commit 661cad4
Show file tree
Hide file tree
Showing 2 changed files with 116 additions and 14 deletions.
77 changes: 63 additions & 14 deletions tensordict/nn/sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@

from __future__ import annotations

import collections
import logging
from copy import deepcopy
from typing import Any, Iterable, List
from typing import Any, Callable, Iterable, List, OrderedDict, overload

from tensordict._nestedkey import NestedKey

Expand Down Expand Up @@ -170,19 +171,57 @@ class TensorDictSequential(TensorDictModule):
module: nn.ModuleList
_select_before_return = False

@overload
def __init__(
self,
*modules: TensorDictModuleBase,
modules: OrderedDict[str, Callable[[TensorDictBase], TensorDictBase]],
*,
partial_tolerant: bool = False,
selected_out_keys: List[NestedKey] | None = None,
) -> None: ...

@overload
def __init__(
self,
modules: List[Callable[[TensorDictBase], TensorDictBase]],
*,
partial_tolerant: bool = False,
selected_out_keys: List[NestedKey] | None = None,
) -> None: ...

def __init__(
self,
*modules: Callable[[TensorDictBase], TensorDictBase],
partial_tolerant: bool = False,
selected_out_keys: List[NestedKey] | None = None,
) -> None:
modules = self._convert_modules(modules)
in_keys, out_keys = self._compute_in_and_out_keys(modules)
self._complete_out_keys = list(out_keys)

super().__init__(
module=nn.ModuleList(list(modules)), in_keys=in_keys, out_keys=out_keys
)
if len(modules) == 1 and isinstance(modules[0], collections.OrderedDict):
modules_vals = self._convert_modules(modules[0].values())
in_keys, out_keys = self._compute_in_and_out_keys(modules_vals)
self._complete_out_keys = list(out_keys)
modules = collections.OrderedDict(
**{key: val for key, val in zip(modules[0], modules_vals)}
)
super().__init__(
module=nn.ModuleDict(modules), in_keys=in_keys, out_keys=out_keys
)
elif len(modules) == 1 and isinstance(
modules[0], collections.abc.MutableSequence
):
modules = self._convert_modules(modules[0])
in_keys, out_keys = self._compute_in_and_out_keys(modules)
self._complete_out_keys = list(out_keys)
super().__init__(
module=nn.ModuleList(modules), in_keys=in_keys, out_keys=out_keys
)
else:
modules = self._convert_modules(modules)
in_keys, out_keys = self._compute_in_and_out_keys(modules)
self._complete_out_keys = list(out_keys)
super().__init__(
module=nn.ModuleList(list(modules)), in_keys=in_keys, out_keys=out_keys
)

self.partial_tolerant = partial_tolerant
if selected_out_keys:
Expand Down Expand Up @@ -408,7 +447,7 @@ def select_subsequence(
out_keys = deepcopy(self.out_keys)
out_keys = unravel_key_list(out_keys)

module_list = list(self.module)
module_list = list(self._module_iter())
id_to_keep = set(range(len(module_list)))
for i, module in enumerate(module_list):
if (
Expand Down Expand Up @@ -445,8 +484,12 @@ def select_subsequence(
raise ValueError(
"No modules left after selection. Make sure that in_keys and out_keys are coherent."
)

return type(self)(*modules)
if isinstance(self.module, nn.ModuleList):
return type(self)(*modules)
else:
keys = [key for key in self.module if self.module[key] in modules]
modules_dict = OrderedDict(**{key: val for key, val in zip(keys, modules)})
return type(self)(modules_dict)

def _run_module(
self,
Expand All @@ -466,6 +509,12 @@ def _run_module(
module(sub_td, **kwargs)
return tensordict

def _module_iter(self):
if isinstance(self.module, nn.ModuleDict):
yield from self.module.children()
else:
yield from self.module

@dispatch(auto_batch_size=False)
@_set_skip_existing_None()
def forward(
Expand All @@ -481,7 +530,7 @@ def forward(
else:
tensordict_exec = tensordict
if not len(kwargs):
for module in self.module:
for module in self._module_iter():
tensordict_exec = self._run_module(module, tensordict_exec, **kwargs)
else:
raise RuntimeError(
Expand Down Expand Up @@ -510,8 +559,8 @@ def forward(
def __len__(self) -> int:
return len(self.module)

def __getitem__(self, index: int | slice) -> TensorDictModuleBase:
if isinstance(index, int):
def __getitem__(self, index: int | slice | str) -> TensorDictModuleBase:
if isinstance(index, (int, str)):
return self.module.__getitem__(index)
else:
return type(self)(*self.module.__getitem__(index))
Expand Down
53 changes: 53 additions & 0 deletions test/test_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import pickle
import unittest
import weakref
from collections import OrderedDict

import pytest
import torch
Expand Down Expand Up @@ -797,6 +798,58 @@ def test_tdmodule_inplace(self):


class TestTDSequence:
def test_ordered_dict(self):
linear = nn.Linear(3, 4)
linear.weight.data.fill_(0)
linear.bias.data.fill_(1)
layer0 = TensorDictModule(linear, in_keys=["x"], out_keys=["y"])
ordered_dict = OrderedDict(
layer0=layer0,
layer1=lambda x: x + 1,
)
seq = TensorDictSequential(ordered_dict)
td = seq(TensorDict(x=torch.ones(3)))
assert (td["x"] == 2).all()
assert (td["y"] == 2).all()
assert seq["layer0"] is layer0

def test_ordered_dict_select_subsequence(self):
ordered_dict = OrderedDict(
layer0=TensorDictModule(lambda x: x + 1, in_keys=["x"], out_keys=["y"]),
layer1=TensorDictModule(lambda x: x - 1, in_keys=["y"], out_keys=["z"]),
layer2=TensorDictModule(
lambda x, y: x + y, in_keys=["x", "y"], out_keys=["a"]
),
)
seq = TensorDictSequential(ordered_dict)
assert len(seq) == 3
assert isinstance(seq.module, nn.ModuleDict)
seq_select = seq.select_subsequence(out_keys=["a"])
assert len(seq_select) == 2
assert isinstance(seq_select.module, nn.ModuleDict)
assert list(seq_select.module) == ["layer0", "layer2"]

def test_ordered_dict_select_outkeys(self):
ordered_dict = OrderedDict(
layer0=TensorDictModule(
lambda x: x + 1, in_keys=["x"], out_keys=["intermediate"]
),
layer1=TensorDictModule(
lambda x: x - 1, in_keys=["intermediate"], out_keys=["z"]
),
layer2=TensorDictModule(
lambda x, y: x + y, in_keys=["x", "z"], out_keys=["a"]
),
)
seq = TensorDictSequential(ordered_dict)
assert len(seq) == 3
assert isinstance(seq.module, nn.ModuleDict)
seq.select_out_keys("z", "a")
td = seq(TensorDict(x=0))
assert "intermediate" not in td
assert "z" in td
assert "a" in td

@pytest.mark.parametrize("args", [True, False])
def test_input_keys(self, args):
module0 = TensorDictModule(lambda x: x + 0, in_keys=["input"], out_keys=["1"])
Expand Down

0 comments on commit 661cad4

Please sign in to comment.