
Commit

Update
[ghstack-poisoned]
vmoens committed Dec 18, 2024
2 parents 1bd2406 + a1fe539 commit 5584ffa
Showing 3 changed files with 70 additions and 25 deletions.
65 changes: 46 additions & 19 deletions tensordict/nn/probabilistic.py
@@ -7,9 +7,10 @@
 import re
 import warnings
+from collections.abc import MutableSequence

 from textwrap import indent
-from typing import Any, Dict, List, Optional, overload, OrderedDict
+from typing import Any, Dict, List, Optional, OrderedDict, overload

 import torch

@@ -621,9 +622,12 @@ class ProbabilisticTensorDictSequential(TensorDictSequential):
         log(p(z | x, y))

     Args:
-        *modules (sequence of TensorDictModules): An ordered sequence of
-            :class:`~tensordict.nn.TensorDictModule` instances, terminating in a :class:`~tensordict.nn.ProbabilisticTensorDictModule`,
+        *modules (sequence or OrderedDict of TensorDictModuleBase or ProbabilisticTensorDictModule): An ordered sequence of
+            :class:`~tensordict.nn.TensorDictModule` instances, usually terminating in a :class:`~tensordict.nn.ProbabilisticTensorDictModule`,
             to be run sequentially.
+            The modules can be instances of TensorDictModuleBase or any other function that matches this signature.
+            Note that if a non-TensorDictModuleBase callable is used, its input and output keys will not be tracked,
+            and thus will not affect the `in_keys` and `out_keys` attributes of the TensorDictSequential.

     Keyword Args:
         partial_tolerant (bool, optional): If ``True``, the input tensordict can miss some
@@ -794,14 +798,13 @@ class ProbabilisticTensorDictSequential(TensorDictSequential):
     @overload
     def __init__(
         self,
-        modules: OrderedDict,
+        modules: OrderedDict[str, TensorDictModuleBase | ProbabilisticTensorDictModule],
         partial_tolerant: bool = False,
         return_composite: bool | None = None,
         aggregate_probabilities: bool | None = None,
         include_sum: bool | None = None,
         inplace: bool | None = None,
-    ) -> None:
-        ...
+    ) -> None: ...

     @overload
     def __init__(
@@ -812,8 +815,7 @@ def __init__(
         aggregate_probabilities: bool | None = None,
         include_sum: bool | None = None,
         inplace: bool | None = None,
-    ) -> None:
-        ...
+    ) -> None: ...

     def __init__(
         self,
@@ -829,7 +831,14 @@ def __init__(
                 "ProbabilisticTensorDictSequential must consist of zero or more "
                 "TensorDictModules followed by a ProbabilisticTensorDictModule"
             )
-        if not return_composite and not isinstance(
+        self._ordered_dict = False
+        if len(modules) == 1 and isinstance(modules[0], (OrderedDict, MutableSequence)):
+            if isinstance(modules[0], OrderedDict):
+                modules_list = list(modules[0].values())
+                self._ordered_dict = True
+            else:
+                modules = modules_list = list(modules[0])
+        elif not return_composite and not isinstance(
             modules[-1],
             (ProbabilisticTensorDictModule, ProbabilisticTensorDictSequential),
         ):
@@ -838,13 +847,22 @@ def __init__(
                 "an instance of ProbabilisticTensorDictModule or another "
                 "ProbabilisticTensorDictSequential (unless return_composite is set to ``True``)."
             )
+        else:
+            modules_list = list(modules)

         # if the modules not including the final probabilistic module return the sampled
         # key we won't be sampling it again, in that case
         # ProbabilisticTensorDictSequential is presumably used to return the
         # distribution using `get_dist` or to sample log_probabilities
-        _, out_keys = self._compute_in_and_out_keys(modules[:-1])
-        self._requires_sample = modules[-1].out_keys[0] not in set(out_keys)
-        self.__dict__["_det_part"] = TensorDictSequential(*modules[:-1])
+        _, out_keys = self._compute_in_and_out_keys(modules_list[:-1])
+        self._requires_sample = modules_list[-1].out_keys[0] not in set(out_keys)
+        if self._ordered_dict:
+            self.__dict__["_det_part"] = TensorDictSequential(
+                OrderedDict(list(modules[0].items())[:-1])
+            )
+        else:
+            self.__dict__["_det_part"] = TensorDictSequential(*modules[:-1])

         super().__init__(*modules, partial_tolerant=partial_tolerant)
         self.return_composite = return_composite
         self.aggregate_probabilities = aggregate_probabilities
@@ -885,7 +903,7 @@ def get_dist_params(
         tds = self.det_part
         type = interaction_type()
         if type is None:
-            for m in reversed(self.module):
+            for m in reversed(list(self._module_iter())):
                 if hasattr(m, "default_interaction_type"):
                     type = m.default_interaction_type
                     break
@@ -897,7 +915,7 @@ def get_dist_params(
     @property
     def num_samples(self):
         num_samples = ()
-        for tdm in self.module:
+        for tdm in self._module_iter():
             if isinstance(
                 tdm, (ProbabilisticTensorDictModule, ProbabilisticTensorDictSequential)
             ):
@@ -941,7 +959,7 @@ def get_dist(

         td_copy = tensordict.copy()
         dists = {}
-        for i, tdm in enumerate(self.module):
+        for i, tdm in enumerate(self._module_iter()):
             if isinstance(
                 tdm, (ProbabilisticTensorDictModule, ProbabilisticTensorDictSequential)
             ):
@@ -981,12 +999,21 @@ def default_interaction_type(self):
         encountered is returned. If no such value is found, a default `interaction_type()` is returned.
         """
-        for m in reversed(self.module):
+        for m in reversed(list(self._module_iter())):
             interaction = getattr(m, "default_interaction_type", None)
             if interaction is not None:
                 return interaction
         return interaction_type()

+    @property
+    def _last_module(self):
+        if not self._ordered_dict:
+            return self.module[-1]
+        mod = None
+        for mod in self._module_iter():  # noqa: B007
+            continue
+        return mod
+
     def log_prob(
         self,
         tensordict,
@@ -1103,7 +1130,7 @@ def log_prob(
                 include_sum=include_sum,
                 **kwargs,
             )
-        last_module: ProbabilisticTensorDictModule = self.module[-1]
+        last_module: ProbabilisticTensorDictModule = self._last_module
         out = last_module.log_prob(tensordict_inp, dist=dist, **kwargs)
         if is_tensor_collection(out):
             if tensordict_out is not None:
@@ -1162,7 +1189,7 @@ def forward(
         else:
             tensordict_exec = tensordict
         if self.return_composite:
-            for m in self.module:
+            for m in self._module_iter():
                 if isinstance(
                     m, (ProbabilisticTensorDictModule, ProbabilisticTensorDictModule)
                 ):
@@ -1173,7 +1200,7 @@ def forward(
                 tensordict_exec = m(tensordict_exec, **kwargs)
         else:
             tensordict_exec = self.get_dist_params(tensordict_exec, **kwargs)
-            tensordict_exec = self.module[-1](
+            tensordict_exec = self._last_module(
                 tensordict_exec, _requires_sample=self._requires_sample
            )
         if tensordict_out is not None:
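For context, a minimal usage sketch of the constructor path these changes enable (illustrative only, not part of the commit): a single OrderedDict can now be passed to ProbabilisticTensorDictSequential, which sets _ordered_dict and builds _det_part from every entry except the last. The module names ("backbone", "proba") and keys ("obs", "loc", "scale", "action") are invented for the example.

import torch
from collections import OrderedDict

from tensordict import TensorDict
from tensordict.nn import (
    InteractionType,
    ProbabilisticTensorDictModule,
    ProbabilisticTensorDictSequential,
    TensorDictModule,
)
from torch import distributions

# Deterministic part: maps "obs" to the distribution parameters "loc" and "scale".
backbone = TensorDictModule(
    lambda obs: (obs.mean(-1, keepdim=True), obs.abs().mean(-1, keepdim=True) + 0.1),
    in_keys=["obs"],
    out_keys=["loc", "scale"],
)
# Probabilistic head: builds a Normal from "loc"/"scale" and writes "action".
proba = ProbabilisticTensorDictModule(
    in_keys=["loc", "scale"],
    out_keys=["action"],
    distribution_class=distributions.Normal,
    default_interaction_type=InteractionType.RANDOM,
    return_log_prob=True,
)
# A single OrderedDict argument takes the new branch in __init__ above.
seq = ProbabilisticTensorDictSequential(OrderedDict(backbone=backbone, proba=proba))
td = seq(TensorDict({"obs": torch.randn(3, 4)}, [3]))
lp = seq.log_prob(td)  # reaches the probabilistic head via the new _last_module property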
8 changes: 6 additions & 2 deletions tensordict/nn/sequence.py
@@ -53,14 +53,18 @@ class TensorDictSequential(TensorDictModule):
         buffers) will be concatenated in a single list.

     Args:
-        modules (iterable of TensorDictModules): ordered sequence of TensorDictModule instances to be run sequentially.
+        modules (OrderedDict[str, Callable[[TensorDictBase], TensorDictBase]] | List[Callable[[TensorDictBase], TensorDictBase]]):
+            ordered sequence of callables that take a TensorDictBase as input and return a TensorDictBase.
+            These can be instances of TensorDictModuleBase or any other function that matches this signature.
+            Note that if a non-TensorDictModuleBase callable is used, its input and output keys will not be tracked,
+            and thus will not affect the `in_keys` and `out_keys` attributes of the TensorDictSequential.

     Keyword Args:
         partial_tolerant (bool, optional): if True, the input tensordict can miss some of the input keys.
             If so, the only module that will be executed are those who can be executed given the keys that
             are present.
             Also, if the input tensordict is a lazy stack of tensordicts AND if partial_tolerant is :obj:`True` AND if the
             stack does not have the required keys, then TensorDictSequential will scan through the sub-tensordicts
-            looking for those that have the required keys, if any.
+            looking for those that have the required keys, if any. Defaults to False.
         selected_out_keys (iterable of NestedKeys, optional): the list of out-keys to select. If not provided, all
             ``out_keys`` will be written.
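A short sketch of the behaviour the updated docstring describes (illustrative, not part of the commit): a plain callable that maps a TensorDict to a TensorDict can be placed in the sequence, but it contributes nothing to the tracked keys. The keys "a", "b" and "c" are invented for the example.

import torch

from tensordict import TensorDict
from tensordict.nn import TensorDictModule, TensorDictSequential

double = TensorDictModule(lambda a: a * 2, in_keys=["a"], out_keys=["b"])

def add_offset(td):
    # Plain function with the TensorDictBase -> TensorDictBase signature.
    td["c"] = td["b"] + 1.0
    return td

seq = TensorDictSequential(double, add_offset)
# Per the note above, only the TensorDictModule's keys are tracked, so "c"
# does not appear in seq.out_keys even though the function writes it.
print(seq.in_keys, seq.out_keys)
out = seq(TensorDict({"a": torch.ones(3)}, [3]))
assert "c" in out.keys()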
22 changes: 18 additions & 4 deletions test/test_nn.py
@@ -480,6 +480,9 @@ def test_stateful_probabilistic_kwargs(self, lazy, it, out_keys, max_dist):

         in_keys = ["in"]
         net = TensorDictModule(module=net, in_keys=in_keys, out_keys=out_keys)
+        corr = TensorDictModule(
+            lambda low: max_dist - low.abs(), in_keys=out_keys, out_keys=out_keys
+        )

         kwargs = {
             "distribution_class": distributions.Uniform,
@@ -494,7 +497,7 @@ def test_stateful_probabilistic_kwargs(self, lazy, it, out_keys, max_dist):
             in_keys=dist_in_keys, out_keys=["out"], **kwargs
         )

-        tensordict_module = ProbabilisticTensorDictSequential(net, prob_module)
+        tensordict_module = ProbabilisticTensorDictSequential(net, corr, prob_module)
         assert tensordict_module.default_interaction_type is not None

         td = TensorDict({"in": torch.randn(3, 3)}, [3])
@@ -2156,6 +2159,8 @@ def test_nested_keys_probabilistic_normal(self, log_prob_key):
             in_keys=[("data", "states")],
             out_keys=[("data", "scale")],
         )
+        scale_module.module.weight.data.abs_()
+        scale_module.module.bias.data.abs_()
         td = TensorDict(
             {"data": TensorDict({"states": torch.zeros(3, 4, 1)}, [3, 4])}, [3]
         )
@@ -3019,7 +3024,8 @@ def test_prob_module_nested(self, interaction, map_names):
         "interaction", [InteractionType.MODE, InteractionType.MEAN]
     )
     @pytest.mark.parametrize("return_log_prob", [True, False])
-    def test_prob_module_seq(self, interaction, return_log_prob):
+    @pytest.mark.parametrize("ordereddict", [True, False])
+    def test_prob_module_seq(self, interaction, return_log_prob, ordereddict):
         params = TensorDict(
             {
                 "params": {
Expand All @@ -3042,7 +3048,7 @@ def test_prob_module_seq(self, interaction, return_log_prob):
("nested", "cont"): distributions.Normal,
}
backbone = TensorDictModule(lambda: None, in_keys=[], out_keys=[])
module = ProbabilisticTensorDictSequential(
args = [
backbone,
ProbabilisticTensorDictModule(
in_keys=in_keys,
Expand All @@ -3052,7 +3058,15 @@ def test_prob_module_seq(self, interaction, return_log_prob):
default_interaction_type=interaction,
return_log_prob=return_log_prob,
),
)
]
if ordereddict:
args = [
OrderedDict(
backbone=args[0],
proba=args[1],
)
]
module = ProbabilisticTensorDictSequential(*args)
sample = module(params)
if return_log_prob:
assert "cont_log_prob" in sample.keys()
Expand Down
