From 98c57eefd2df092584c87e5821ec910a3f80a03b Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 18 Dec 2024 18:28:50 +0000 Subject: [PATCH] [Feature] Force log_prob to return a tensordict when kwargs are passed to ProbabilisticTensorDictSequential.log_prob ghstack-source-id: 326d0763c9bbb13b51daac91edca4f0e821adf62 Pull Request resolved: https://github.com/pytorch/tensordict/pull/1146 --- tensordict/nn/probabilistic.py | 48 +++++++++++++++++++++++++++++----- test/test_nn.py | 24 +++++++++++++++-- 2 files changed, 64 insertions(+), 8 deletions(-) diff --git a/tensordict/nn/probabilistic.py b/tensordict/nn/probabilistic.py index 04fdfb6bb..a9b2cdf5a 100644 --- a/tensordict/nn/probabilistic.py +++ b/tensordict/nn/probabilistic.py @@ -422,7 +422,16 @@ def log_prob( if dist.aggregate_probabilities is not None: aggregate_probabilities_inp = dist.aggregate_probabilities else: - # TODO: warning + warnings.warn( + f"aggregate_probabilities wasn't defined in the {type(self).__name__} instance. " + f"It couldn't be retrieved from the CompositeDistribution object either. " + f"Currently, the aggregate_probability will be `True` in this case but in a future release " + f"(v0.9) this will change and `aggregate_probabilities` will default to ``False`` such " + f"that log_prob will return a tensordict with the log-prob values. To silence this warning, " + f"pass `aggregate_probabilities` to the {type(self).__name__} constructor, to the distribution kwargs " + f"or to the log-prob method.", + category=DeprecationWarning, + ) aggregate_probabilities_inp = False else: aggregate_probabilities_inp = aggregate_probabilities @@ -430,13 +439,32 @@ def log_prob( if dist.inplace is not None: inplace = dist.inplace else: - # TODO: warning + warnings.warn( + f"inplace wasn't defined in the {type(self).__name__} instance. " + f"It couldn't be retrieved from the CompositeDistribution object either. " + f"Currently, the `inplace` will be `True` in this case but in a future release " + f"(v0.9) this will change and `inplace` will default to ``False`` such " + f"that log_prob will return a new tensordict containing only the log-prob values. To silence this warning, " + f"pass `inplace` to the {type(self).__name__} constructor, to the distribution kwargs " + f"or to the log-prob method.", + category=DeprecationWarning, + ) inplace = True if include_sum is None: if dist.include_sum is not None: include_sum = dist.include_sum else: - # TODO: warning + warnings.warn( + f"include_sum wasn't defined in the {type(self).__name__} instance. " + f"It couldn't be retrieved from the CompositeDistribution object either. " + f"Currently, the `include_sum` will be `True` in this case but in a future release " + f"(v0.9) this will change and `include_sum` will default to ``False`` such " + f"that log_prob will return a new tensordict containing only the leaf log-prob values. " + f"To silence this warning, " + f"pass `include_sum` to the {type(self).__name__} constructor, to the distribution kwargs " + f"or to the log-prob method.", + category=DeprecationWarning, + ) include_sum = True lp = dist.log_prob( tensordict, @@ -446,6 +474,7 @@ def log_prob( ) if is_tensor_collection(lp) and aggregate_probabilities is None: return lp.get(dist.log_prob_key) + return lp else: return dist.log_prob(tensordict.get(self.out_keys[0])) @@ -1027,8 +1056,9 @@ def log_prob( ): """Returns the log-probability of the input tensordict. - If `return_composite` is ``True`` and the distribution is a :class:`~tensordict.nn.CompositeDistribution`, - this method will return the log-probability of the entire composite distribution. + If `self.return_composite` is ``True`` and the distribution is a :class:`~tensordict.nn.CompositeDistribution`, + or if any of :attr:`aggregate_probabilities`, :attr:`inplace` or :attr:`include_sum` this method will return + the log-probability of the entire composite distribution. Otherwise, it will only consider the last probabilistic module in the sequence. @@ -1069,7 +1099,13 @@ def log_prob( tensordict_inp = tensordict if dist is None: dist = self.get_dist(tensordict_inp) - if self.return_composite and isinstance(dist, CompositeDistribution): + return_composite = ( + self.return_composite + or (aggregate_probabilities is not None) + or (inplace is not None) + or (include_sum is not None) + ) + if return_composite and isinstance(dist, CompositeDistribution): # Check the values within the dist - if not set, choose defaults if aggregate_probabilities is None: if self.aggregate_probabilities is not None: diff --git a/test/test_nn.py b/test/test_nn.py index 948af946b..af087463b 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -2930,7 +2930,17 @@ def test_prob_module(self, interaction, return_log_prob, map_names): assert key_logprob1 in sample assert all(key in sample for key in module.out_keys) sample_clone = sample.clone() - lp = module.log_prob(sample_clone) + with pytest.warns( + DeprecationWarning, + match="aggregate_probabilities wasn't defined in the ProbabilisticTensorDictModule", + ), pytest.warns( + DeprecationWarning, + match="inplace wasn't defined in the ProbabilisticTensorDictModule", + ), pytest.warns( + DeprecationWarning, + match="include_sum wasn't defined in the ProbabilisticTensorDictModule", + ): + lp = module.log_prob(sample_clone) assert isinstance(lp, torch.Tensor) if return_log_prob: torch.testing.assert_close( @@ -3077,7 +3087,17 @@ def test_prob_module_seq(self, interaction, return_log_prob, ordereddict): assert isinstance(dist, CompositeDistribution) sample_clone = sample.clone() - lp = module.log_prob(sample_clone) + with pytest.warns( + DeprecationWarning, + match="aggregate_probabilities wasn't defined in the ProbabilisticTensorDictModule", + ), pytest.warns( + DeprecationWarning, + match="inplace wasn't defined in the ProbabilisticTensorDictModule", + ), pytest.warns( + DeprecationWarning, + match="include_sum wasn't defined in the ProbabilisticTensorDictModule", + ): + lp = module.log_prob(sample_clone) if return_log_prob: torch.testing.assert_close(