diff --git a/tensordict/nn/distributions/composite.py b/tensordict/nn/distributions/composite.py
index 4285624f3..79a890545 100644
--- a/tensordict/nn/distributions/composite.py
+++ b/tensordict/nn/distributions/composite.py
@@ -414,7 +414,8 @@ def log_prob_composite(
                 "The current default is ``True`` but from v0.9 it will be changed to ``False``. Please adapt your call to `log_prob_composite` accordingly.",
                 category=DeprecationWarning,
             )
-        slp = 0.0
+        if include_sum:
+            slp = 0.0
         d = {}
         for name, dist in self.dists.items():
             d[_add_suffix(name, "_log_prob")] = lp = dist.log_prob(sample.get(name))
diff --git a/test/test_nn.py b/test/test_nn.py
index 13231e927..cdbab76e9 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -12,6 +12,7 @@
 import pytest
 
 import torch
+
 from tensordict import NonTensorData, NonTensorStack, tensorclass, TensorDict
 from tensordict._C import unravel_key_list
 from tensordict.nn import (
@@ -2277,7 +2278,9 @@ def test_log_prob(self):
         assert isinstance(lp, torch.Tensor)
         assert lp.requires_grad
 
-    def test_log_prob_composite(self):
+    @pytest.mark.parametrize("inplace", [None, True, False])
+    @pytest.mark.parametrize("include_sum", [None, True, False])
+    def test_log_prob_composite(self, inplace, include_sum):
         params = TensorDict(
             {
                 "cont": {
@@ -2296,12 +2299,25 @@ def test_log_prob_composite(self):
             },
             extra_kwargs={("nested", "disc"): {"temperature": torch.tensor(1.0)}},
             aggregate_probabilities=False,
+            inplace=inplace,
+            include_sum=include_sum,
         )
+        if include_sum is None:
+            include_sum = True
+        if inplace is None:
+            inplace = True
         sample = dist.rsample((4,))
-        sample = dist.log_prob_composite(sample, include_sum=True)
-        assert sample.get("cont_log_prob").requires_grad
-        assert sample.get(("nested", "disc_log_prob")).requires_grad
-        assert "sample_log_prob" in sample.keys()
+        sample_lp = dist.log_prob_composite(sample)
+        assert sample_lp.get("cont_log_prob").requires_grad
+        assert sample_lp.get(("nested", "disc_log_prob")).requires_grad
+        if inplace:
+            assert sample_lp is sample
+        else:
+            assert sample_lp is not sample
+        if include_sum:
+            assert "sample_log_prob" in sample_lp.keys()
+        else:
+            assert "sample_log_prob" not in sample_lp.keys()
 
     def test_entropy(self):
         params = TensorDict(
@@ -2327,7 +2343,8 @@ def test_entropy(self):
         assert isinstance(ent, torch.Tensor)
         assert ent.requires_grad
 
-    def test_entropy_composite(self):
+    @pytest.mark.parametrize("include_sum", [None, True, False])
+    def test_entropy_composite(self, include_sum):
         params = TensorDict(
             {
                 "cont": {
@@ -2345,12 +2362,18 @@ def test_entropy_composite(self):
                 ("nested", "disc"): distributions.Categorical,
             },
             aggregate_probabilities=False,
+            include_sum=include_sum,
         )
+        if include_sum is None:
+            include_sum = True
         sample = dist.entropy()
         assert sample.shape == params.shape == dist._batch_shape
         assert sample.get("cont_entropy").requires_grad
         assert sample.get(("nested", "disc_entropy")).requires_grad
-        assert "entropy" in sample.keys()
+        if include_sum:
+            assert "entropy" in sample.keys()
+        else:
+            assert "entropy" not in sample.keys()
 
     def test_cdf(self):
         params = TensorDict(
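
For reference, a minimal usage sketch of the behavior these changes pin down, assuming a `CompositeDistribution` built the way the tests above build one; the Normal-only distribution map and the parameter values are illustrative, not taken from the patch:

```python
# Hypothetical usage sketch (not part of the patch): exercises the
# `include_sum` / `inplace` kwargs that the tests above parametrize over.
import torch
from tensordict import TensorDict
from tensordict.nn import CompositeDistribution
from torch import distributions

params = TensorDict(
    {
        "cont": {
            "loc": torch.randn(3, 4, requires_grad=True),
            "scale": torch.rand(3, 4) + 0.1,
        },
        ("nested", "cont2"): {
            "loc": torch.randn(3, 4, requires_grad=True),
            "scale": torch.rand(3, 4) + 0.1,
        },
    },
    [3],
)
dist = CompositeDistribution(
    params,
    distribution_map={
        "cont": distributions.Normal,
        ("nested", "cont2"): distributions.Normal,
    },
    aggregate_probabilities=False,
    include_sum=False,  # skip the aggregated "sample_log_prob" entry
    inplace=False,      # return a new TensorDict instead of writing into `sample`
)
sample = dist.rsample((4,))
lp = dist.log_prob_composite(sample)
assert lp is not sample                    # inplace=False: fresh TensorDict
assert "sample_log_prob" not in lp.keys()  # include_sum=False: no summed entry
assert lp.get("cont_log_prob").requires_grad
assert lp.get(("nested", "cont2_log_prob")).requires_grad
```

Per the deprecation warning in the first hunk, leaving `include_sum` (and, per the tests, `inplace`) unset currently defaults to `True`, with the default slated to flip to `False` in v0.9.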