
Commit

[Feature] Force log_prob to return a tensordict when kwargs are passed to ProbabilisticTensorDictSequential.log_prob

ghstack-source-id: 326d0763c9bbb13b51daac91edca4f0e821adf62
Pull Request resolved: #1146
vmoens committed Dec 19, 2024
1 parent 2d37d92 commit 98c57ee
Showing 2 changed files with 64 additions and 8 deletions.
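
Before the diff itself, a minimal sketch of the behaviour this commit enables. Everything in the snippet except the `log_prob` keyword arguments is an illustrative assumption: the module wiring, the key names (`x`, `cont`, `params`) and the lambda producing the distribution parameters are modelled loosely on the test suite. The point taken from the change is that passing `aggregate_probabilities`, `include_sum` or `inplace` to `ProbabilisticTensorDictSequential.log_prob` forces the composite code path, so the call returns a tensordict of log-probabilities rather than a single tensor.

```python
import torch
from tensordict import TensorDict
from tensordict.nn import (
    CompositeDistribution,
    ProbabilisticTensorDictModule,
    ProbabilisticTensorDictSequential,
    TensorDictModule,
)
from torch import distributions as d

# A deterministic module producing the parameters of a composite
# distribution (names and shapes are illustrative).
params_module = TensorDictModule(
    lambda x: (x, x.exp()),
    in_keys=["x"],
    out_keys=[("params", "cont", "loc"), ("params", "cont", "scale")],
)
prob_module = ProbabilisticTensorDictModule(
    in_keys=["params"],
    out_keys=["cont"],
    distribution_class=CompositeDistribution,
    distribution_kwargs={"distribution_map": {"cont": d.Normal}},
)
seq = ProbabilisticTensorDictSequential(params_module, prob_module)

# A tensordict carrying both the input of the parameter module and a
# candidate sample for the "cont" leaf.
td = TensorDict({"x": torch.randn(3), "cont": torch.randn(3)}, batch_size=[3])

# With the kwargs spelled out, log_prob takes the composite path and
# returns a tensordict of log-prob entries (e.g. "cont_log_prob")
# instead of a single aggregated tensor; since nothing is left
# unspecified, it should not trigger the deprecation warnings added
# in the diff below.
lp = seq.log_prob(td, aggregate_probabilities=False, include_sum=False, inplace=False)
print(lp)
```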
48 changes: 42 additions & 6 deletions tensordict/nn/probabilistic.py
@@ -422,21 +422,49 @@ def log_prob(
if dist.aggregate_probabilities is not None:
aggregate_probabilities_inp = dist.aggregate_probabilities
else:
# TODO: warning
warnings.warn(
f"aggregate_probabilities wasn't defined in the {type(self).__name__} instance. "
f"It couldn't be retrieved from the CompositeDistribution object either. "
f"Currently, the aggregate_probability will be `True` in this case but in a future release "
f"(v0.9) this will change and `aggregate_probabilities` will default to ``False`` such "
f"that log_prob will return a tensordict with the log-prob values. To silence this warning, "
f"pass `aggregate_probabilities` to the {type(self).__name__} constructor, to the distribution kwargs "
f"or to the log-prob method.",
category=DeprecationWarning,
)
aggregate_probabilities_inp = False
else:
aggregate_probabilities_inp = aggregate_probabilities
if inplace is None:
if dist.inplace is not None:
inplace = dist.inplace
else:
# TODO: warning
warnings.warn(
f"inplace wasn't defined in the {type(self).__name__} instance. "
f"It couldn't be retrieved from the CompositeDistribution object either. "
f"Currently, the `inplace` will be `True` in this case but in a future release "
f"(v0.9) this will change and `inplace` will default to ``False`` such "
f"that log_prob will return a new tensordict containing only the log-prob values. To silence this warning, "
f"pass `inplace` to the {type(self).__name__} constructor, to the distribution kwargs "
f"or to the log-prob method.",
category=DeprecationWarning,
)
inplace = True
if include_sum is None:
if dist.include_sum is not None:
include_sum = dist.include_sum
else:
# TODO: warning
warnings.warn(
f"include_sum wasn't defined in the {type(self).__name__} instance. "
f"It couldn't be retrieved from the CompositeDistribution object either. "
f"Currently, the `include_sum` will be `True` in this case but in a future release "
f"(v0.9) this will change and `include_sum` will default to ``False`` such "
f"that log_prob will return a new tensordict containing only the leaf log-prob values. "
f"To silence this warning, "
f"pass `include_sum` to the {type(self).__name__} constructor, to the distribution kwargs "
f"or to the log-prob method.",
category=DeprecationWarning,
)
include_sum = True
lp = dist.log_prob(
tensordict,
@@ -446,6 +474,7 @@
)
if is_tensor_collection(lp) and aggregate_probabilities is None:
return lp.get(dist.log_prob_key)
return lp
else:
return dist.log_prob(tensordict.get(self.out_keys[0]))

@@ -1027,8 +1056,9 @@ def log_prob(
):
"""Returns the log-probability of the input tensordict.
If `return_composite` is ``True`` and the distribution is a :class:`~tensordict.nn.CompositeDistribution`,
this method will return the log-probability of the entire composite distribution.
If `self.return_composite` is ``True`` and the distribution is a :class:`~tensordict.nn.CompositeDistribution`,
or if any of :attr:`aggregate_probabilities`, :attr:`inplace` or :attr:`include_sum` is set, this method will return
the log-probability of the entire composite distribution.
Otherwise, it will only consider the last probabilistic module in the sequence.
@@ -1069,7 +1099,13 @@ def log_prob(
tensordict_inp = tensordict
if dist is None:
dist = self.get_dist(tensordict_inp)
if self.return_composite and isinstance(dist, CompositeDistribution):
return_composite = (
self.return_composite
or (aggregate_probabilities is not None)
or (inplace is not None)
or (include_sum is not None)
)
if return_composite and isinstance(dist, CompositeDistribution):
# Check the values within the dist - if not set, choose defaults
if aggregate_probabilities is None:
if self.aggregate_probabilities is not None:
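
The three warnings added above all point to the same remedy: state `aggregate_probabilities`, `include_sum` and `inplace` explicitly rather than relying on the transitional defaults. Reusing the names from the sketch above the diff, here is a hedged illustration of the two most direct options named in the warning text; the constructor keyword arguments are taken from that text and the surrounding code, not independently verified.

```python
# Option 1: be explicit at call time, as in the sketch above.
lp = seq.log_prob(td, aggregate_probabilities=False, include_sum=False, inplace=False)

# Option 2: fix the defaults once on the probabilistic module, as the warning
# message suggests ("pass `aggregate_probabilities` to the ... constructor").
prob_module = ProbabilisticTensorDictModule(
    in_keys=["params"],
    out_keys=["cont"],
    distribution_class=CompositeDistribution,
    distribution_kwargs={"distribution_map": {"cont": d.Normal}},
    aggregate_probabilities=False,
    include_sum=False,
    inplace=False,
)
```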
24 changes: 22 additions & 2 deletions test/test_nn.py
@@ -2930,7 +2930,17 @@ def test_prob_module(self, interaction, return_log_prob, map_names):
assert key_logprob1 in sample
assert all(key in sample for key in module.out_keys)
sample_clone = sample.clone()
lp = module.log_prob(sample_clone)
with pytest.warns(
DeprecationWarning,
match="aggregate_probabilities wasn't defined in the ProbabilisticTensorDictModule",
), pytest.warns(
DeprecationWarning,
match="inplace wasn't defined in the ProbabilisticTensorDictModule",
), pytest.warns(
DeprecationWarning,
match="include_sum wasn't defined in the ProbabilisticTensorDictModule",
):
lp = module.log_prob(sample_clone)
assert isinstance(lp, torch.Tensor)
if return_log_prob:
torch.testing.assert_close(
@@ -3077,7 +3087,17 @@ def test_prob_module_seq(self, interaction, return_log_prob, ordereddict):
assert isinstance(dist, CompositeDistribution)

sample_clone = sample.clone()
lp = module.log_prob(sample_clone)
with pytest.warns(
DeprecationWarning,
match="aggregate_probabilities wasn't defined in the ProbabilisticTensorDictModule",
), pytest.warns(
DeprecationWarning,
match="inplace wasn't defined in the ProbabilisticTensorDictModule",
), pytest.warns(
DeprecationWarning,
match="include_sum wasn't defined in the ProbabilisticTensorDictModule",
):
lp = module.log_prob(sample_clone)

if return_log_prob:
torch.testing.assert_close(

1 comment on commit 98c57ee

@github-actions


⚠️ Performance Alert ⚠️

Possible performance regression was detected for benchmark 'CPU Benchmark Results'.
The benchmark result for this commit is worse than the previous result by a ratio exceeding the alert threshold of 2.

Benchmark suite: benchmarks/common/memmap_benchmarks_test.py::test_serialize_weights_pickle
Current (98c57ee): 1.0697397291532058 iter/sec (stddev: 0.13124384881990853)
Previous (eaafc18): 2.5281468358521666 iter/sec (stddev: 0.05007958269568729)
Ratio: 2.36

This comment was automatically generated by a workflow using github-action-benchmark.

CC: @vmoens
