diff --git a/tensordict/nn/distributions/composite.py b/tensordict/nn/distributions/composite.py
index a68136014..79a890545 100644
--- a/tensordict/nn/distributions/composite.py
+++ b/tensordict/nn/distributions/composite.py
@@ -5,6 +5,7 @@
 from __future__ import annotations
 
 import warnings
+from typing import Dict
 
 import torch
 from tensordict import TensorDict, TensorDictBase
@@ -121,6 +122,105 @@ def __init__(
         self.include_sum = include_sum
         self.inplace = inplace
 
+    @classmethod
+    def from_distributions(
+        cls,
+        params,
+        distributions: Dict[NestedKey, d.Distribution],
+        *,
+        name_map: dict | None = None,
+        aggregate_probabilities: bool | None = None,
+        log_prob_key: NestedKey = "sample_log_prob",
+        entropy_key: NestedKey = "entropy",
+        inplace: bool | None = None,
+        include_sum: bool | None = None,
+    ) -> CompositeDistribution:
+        """Create a `CompositeDistribution` instance from existing distribution objects.
+
+        This class method builds a `CompositeDistribution` directly from a dictionary of distribution
+        instances, rather than from distribution classes and their parameters.
+
+        Args:
+            params (TensorDictBase): A TensorDict that defines the batch shape for the composite distribution.
+                Only the batch size of this tensordict is read; the parameter values themselves are not used,
+                since the distributions are provided already constructed.
+            distributions (Dict[NestedKey, d.Distribution]): A dictionary mapping nested keys to distribution instances.
+                These distributions are used directly in the composite distribution.
+
+        Keyword Args:
+            name_map (Dict[NestedKey, NestedKey], optional): A mapping of where each sample should be written. If not provided,
+                the key names from `distributions` will be used.
+            aggregate_probabilities (bool, optional): If `True`, the `log_prob` and `entropy` methods will sum the log-probabilities and entropies
+                of the individual distributions and return a single tensor. If `False`, individual log-probabilities will be stored in the input
+                TensorDict (for `log_prob`) or returned as leaves of the output TensorDict (for `entropy`). This can be overridden at runtime
+                by passing the `aggregate_probabilities` argument to `log_prob` and `entropy`. Defaults to `False`.
+            log_prob_key (NestedKey, optional): The key where the log probability will be stored. Defaults to `'sample_log_prob'`.
+            entropy_key (NestedKey, optional): The key where the entropy will be stored. Defaults to `'entropy'`.
+            inplace (bool, optional): Whether to modify the input TensorDict in-place. Defaults to `True`.
+
+                .. warning:: The default value of ``inplace`` will switch to ``False`` in v0.9 in the constructor.
+
+            include_sum (bool, optional): Whether to include the summed log-probability in the output TensorDict. Defaults to `True`.
+
+                .. warning:: The default value of ``include_sum`` will switch to ``False`` in v0.9 in the constructor.
+
+        Returns:
+            CompositeDistribution: An instance of `CompositeDistribution` initialized with the provided distributions.
+
+        Raises:
+            KeyError: If a key in `name_map` cannot be found in the provided distributions.
+
+        .. note:: The batch size of the `params` TensorDict determines the batch shape of the composite distribution.
+
+        Example:
+            >>> from tensordict.nn import CompositeDistribution, ProbabilisticTensorDictSequential, ProbabilisticTensorDictModule, TensorDictModule
+            >>> import torch
+            >>> from tensordict import TensorDict
+            >>>
+            >>> # Values are not used to build the dists
+            >>> params = TensorDict({("0", "loc"): None, ("1", "loc"): None, ("0", "scale"): None, ("1", "scale"): None})
+            >>> d0 = torch.distributions.Normal(0, 1)
+            >>> d1 = torch.distributions.Normal(torch.zeros(1, 2), torch.ones(1, 2))
+            >>>
+            >>> d = CompositeDistribution.from_distributions(params, {"0": d0, "1": d1})
+            >>> print(d.sample())
+            TensorDict(
+                fields={
+                    0: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
+                    1: Tensor(shape=torch.Size([1, 2]), device=cpu, dtype=torch.float32, is_shared=False)},
+                batch_size=torch.Size([]),
+                device=None,
+                is_shared=False)
+        """
+        self = cls.__new__(cls)
+        self._batch_shape = params.shape
+        dists = {}
+        if name_map is not None:
+            name_map = {
+                unravel_key(key): unravel_key(other_key)
+                for key, other_key in name_map.items()
+            }
+        for name, dist in distributions.items():
+            name_unravel = unravel_key(name)
+            if name_map:
+                try:
+                    write_name = unravel_key(name_map.get(name, name_unravel))
+                except KeyError:
+                    raise KeyError(
+                        f"Failed to retrieve the key {name} from the name_map with keys {name_map.keys()}."
+                    )
+            else:
+                write_name = name_unravel
+            dists[write_name] = dist
+        self.dists = dists
+        self.log_prob_key = log_prob_key
+        self.entropy_key = entropy_key
+
+        self.aggregate_probabilities = aggregate_probabilities
+        self.include_sum = include_sum
+        self.inplace = inplace
+        return self
+
     @property
     def aggregate_probabilities(self):
         aggregate_probabilities = self._aggregate_probabilities
diff --git a/test/test_nn.py b/test/test_nn.py
index 4bab031a9..cdbab76e9 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -2166,6 +2166,29 @@ def test_sample(self):
         sample = dist.sample((4,))
         assert sample.shape == torch.Size((4,) + params.shape)
 
+    def test_from_distributions(self):
+
+        # Values are not used to build the dists
+        params = TensorDict(
+            {
+                ("0", "loc"): None,
+                ("1", "nested", "loc"): None,
+                ("0", "scale"): None,
+                ("1", "nested", "scale"): None,
+            }
+        )
+        d0 = torch.distributions.Normal(0, 1)
+        d1 = torch.distributions.Normal(torch.zeros(1, 2), torch.ones(1, 2))
+
+        d = CompositeDistribution.from_distributions(
+            params, {"0": d0, ("1", "nested"): d1}
+        )
+        s = d.sample()
+        assert s["0"].shape == ()
+        assert s["1", "nested"].shape == (1, 2)
+        assert isinstance(s["0"], torch.Tensor)
+        assert isinstance(s["1", "nested"], torch.Tensor)
+
     def test_sample_named(self):
         params = TensorDict(
             {
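
For reference, a minimal usage sketch of the new classmethod with the `name_map` argument, which neither the docstring example nor the test above exercises. It assumes the patched tensordict from this diff; the output keys ("zero", "sample") and ("one", "sample") are arbitrary names chosen for illustration.

import torch
from tensordict import TensorDict
from tensordict.nn import CompositeDistribution

# params only supplies the batch shape; its (None) values are never read.
params = TensorDict({("0", "loc"): None, ("1", "loc"): None})
d0 = torch.distributions.Normal(0, 1)
d1 = torch.distributions.Normal(torch.zeros(1, 2), torch.ones(1, 2))

# name_map redirects where each distribution's sample is written in the output tensordict.
dist = CompositeDistribution.from_distributions(
    params,
    {"0": d0, "1": d1},
    name_map={"0": ("zero", "sample"), "1": ("one", "sample")},
)
sample = dist.sample()
assert sample["zero", "sample"].shape == torch.Size([])
assert sample["one", "sample"].shape == (1, 2)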