[Feature] CompositeDistribution.from_distributions #1113

Merged: 2 commits, Nov 27, 2024
100 changes: 100 additions & 0 deletions tensordict/nn/distributions/composite.py
@@ -5,6 +5,7 @@
from __future__ import annotations

import warnings
from typing import Dict

import torch
from tensordict import TensorDict, TensorDictBase
@@ -121,6 +122,105 @@ def __init__(
        self.include_sum = include_sum
        self.inplace = inplace

    @classmethod
    def from_distributions(
        cls,
        params,
        distributions: Dict[NestedKey, d.Distribution],
        *,
        name_map: dict | None = None,
        aggregate_probabilities: bool | None = None,
        log_prob_key: NestedKey = "sample_log_prob",
        entropy_key: NestedKey = "entropy",
        inplace: bool | None = None,
        include_sum: bool | None = None,
    ) -> CompositeDistribution:
        """Create a `CompositeDistribution` instance from existing distribution objects.

        This class method allows a `CompositeDistribution` to be built by directly providing
        a dictionary of distribution instances, rather than specifying distribution types and
        parameters separately.

        Args:
            params (TensorDictBase): A TensorDict whose batch size defines the batch shape of the
                composite distribution. Only the batch size is read by this method; the values
                stored in ``params`` are ignored.
            distributions (Dict[NestedKey, d.Distribution]): A dictionary mapping nested keys to
                distribution instances. These distributions are used directly in the composite
                distribution.

        Keyword Args:
            name_map (Dict[NestedKey, NestedKey], optional): A mapping of where each sample should be
                written. If not provided, the key names from ``distributions`` will be used.
            aggregate_probabilities (bool, optional): If ``True``, the ``log_prob`` and ``entropy``
                methods will sum the log-probabilities and entropies of the individual distributions
                and return a single tensor. If ``False``, individual log-probabilities will be stored
                in the input TensorDict (for ``log_prob``) or returned as leaves of the output
                TensorDict (for ``entropy``). This can be overridden at runtime by passing the
                ``aggregate_probabilities`` argument to ``log_prob`` and ``entropy``. Defaults to ``False``.
            log_prob_key (NestedKey, optional): The key where the log probability will be stored.
                Defaults to ``'sample_log_prob'``.
            entropy_key (NestedKey, optional): The key where the entropy will be stored. Defaults to ``'entropy'``.
            inplace (bool, optional): Whether to modify the input TensorDict in-place. Defaults to ``True``.

                .. warning:: The default value of ``inplace`` will switch to ``False`` in the constructor in v0.9.

            include_sum (bool, optional): Whether to include the summed log-probability in the output
                TensorDict. Defaults to ``True``.

                .. warning:: The default value of ``include_sum`` will switch to ``False`` in the constructor in v0.9.

        Returns:
            CompositeDistribution: An instance of `CompositeDistribution` initialized with the provided distributions.

        Raises:
            KeyError: If a key in `name_map` cannot be found in the provided distributions.

        .. note:: The batch size of the `params` TensorDict determines the batch shape of the composite distribution.

        Example:
            >>> import torch
            >>> from tensordict import TensorDict
            >>> from tensordict.nn import CompositeDistribution
            >>>
            >>> # Values are not used to build the dists
            >>> params = TensorDict({("0", "loc"): None, ("1", "loc"): None, ("0", "scale"): None, ("1", "scale"): None})
            >>> d0 = torch.distributions.Normal(0, 1)
            >>> d1 = torch.distributions.Normal(torch.zeros(1, 2), torch.ones(1, 2))
            >>>
            >>> d = CompositeDistribution.from_distributions(params, {"0": d0, "1": d1})
            >>> print(d.sample())
            TensorDict(
                fields={
                    0: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
                    1: Tensor(shape=torch.Size([1, 2]), device=cpu, dtype=torch.float32, is_shared=False)},
                batch_size=torch.Size([]),
                device=None,
                is_shared=False)
        """
        self = cls.__new__(cls)
        self._batch_shape = params.shape
        dists = {}
        if name_map is not None:
            # Normalize keys so that string and tuple forms of the same nested key match.
            name_map = {
                unravel_key(key): unravel_key(other_key)
                for key, other_key in name_map.items()
            }
        for name, dist in distributions.items():
            name_unravel = unravel_key(name)
            if name_map:
                try:
                    write_name = unravel_key(name_map.get(name, name_unravel))
                except KeyError:
                    raise KeyError(
                        f"Failed to retrieve the key {name} from the name_map with keys {name_map.keys()}."
                    )
            else:
                write_name = name_unravel
            # Samples for this distribution will be written at `write_name`.
            dists[write_name] = dist
        self.dists = dists
        self.log_prob_key = log_prob_key
        self.entropy_key = entropy_key

        self.aggregate_probabilities = aggregate_probabilities
        self.include_sum = include_sum
        self.inplace = inplace
        return self

    @property
    def aggregate_probabilities(self):
        aggregate_probabilities = self._aggregate_probabilities
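Usage note (not part of the diff): the sketch below exercises the `name_map` and `aggregate_probabilities` arguments documented above. The destination key ("agent", "action") is an illustrative choice, and the single-tensor return of `log_prob` follows the docstring's description of `aggregate_probabilities=True`.

import torch
from tensordict import TensorDict
from tensordict.nn import CompositeDistribution

# Values are not used to build the dists; only the batch size of `params` matters.
params = TensorDict({("0", "loc"): None, ("0", "scale"): None})
d0 = torch.distributions.Normal(torch.zeros(3), torch.ones(3))

dist = CompositeDistribution.from_distributions(
    params,
    {"0": d0},
    name_map={"0": ("agent", "action")},  # write the sample at ("agent", "action") instead of "0"
    aggregate_probabilities=True,  # sum per-key log-probs into a single tensor
)
sample = dist.sample()
assert sample["agent", "action"].shape == (3,)
lp = dist.log_prob(sample)  # a single tensor, per the docstring, rather than per-key entries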
23 changes: 23 additions & 0 deletions test/test_nn.py
@@ -2166,6 +2166,29 @@ def test_sample(self):
        sample = dist.sample((4,))
        assert sample.shape == torch.Size((4,) + params.shape)

    def test_from_distributions(self):
        # Values are not used to build the dists
        params = TensorDict(
            {
                ("0", "loc"): None,
                ("1", "nested", "loc"): None,
                ("0", "scale"): None,
                ("1", "nested", "scale"): None,
            }
        )
        d0 = torch.distributions.Normal(0, 1)
        d1 = torch.distributions.Normal(torch.zeros(1, 2), torch.ones(1, 2))

        d = CompositeDistribution.from_distributions(
            params, {"0": d0, ("1", "nested"): d1}
        )
        s = d.sample()
        assert s["0"].shape == ()
        assert s["1", "nested"].shape == (1, 2)
        assert isinstance(s["0"], torch.Tensor)
        assert isinstance(s["1", "nested"], torch.Tensor)

    def test_sample_named(self):
        params = TensorDict(
            {
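A hypothetical companion test (not part of this PR) could pin down the `name_map` remapping path of `from_distributions`; the method name and keys below are illustrative.

    def test_from_distributions_name_map(self):
        # Values are not used to build the dists
        params = TensorDict({("0", "loc"): None, ("0", "scale"): None})
        d0 = torch.distributions.Normal(torch.zeros(2), torch.ones(2))
        d = CompositeDistribution.from_distributions(
            params, {"0": d0}, name_map={"0": ("remapped", "sample")}
        )
        s = d.sample()
        # The sample lands at the remapped key, not at "0"
        assert s["remapped", "sample"].shape == (2,)
        assert "0" not in s.keys()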