Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Nov 29, 2024
2 parents 330fb5e + 9a19a57 commit 0b1f8b3
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 2 deletions.
10 changes: 9 additions & 1 deletion tensordict/nn/probabilistic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1145,7 +1145,15 @@ def forward(
if self._select_before_return:
# We must also update any value that has been updated during the course of execution
# from the input data.
keys = list(set(self.out_keys + list(tensordict.keys(True, True))))
if is_compiling():
keys = [ # noqa: C416
k
for k in {k for k in self.out_keys}.union( # noqa: C416
{k for k in tensordict.keys(True, True)} # noqa: C416
)
]
else:
keys = list(set(self.out_keys + list(tensordict.keys(True, True))))
return tensordict.update(result, keys_to_update=keys)
return result

Expand Down
14 changes: 13 additions & 1 deletion tensordict/nn/sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@
)
FUNCTORCH_ERROR = "functorch not installed. Consider installing functorch to use this functionality."

try:
from torch.compiler import is_compiling
except ImportError:
from torch._dynamo import is_compiling

__all__ = ["TensorDictSequential"]

Expand Down Expand Up @@ -491,7 +495,15 @@ def forward(
if self._select_before_return:
# We must also update any value that has been updated during the course of execution
# from the input data.
keys = list(set(self.out_keys + list(tensordict.keys(True, True))))
if is_compiling():
keys = [ # noqa: C416
k
for k in {k for k in self.out_keys}.union( # noqa: C416
{k for k in tensordict.keys(True, True)} # noqa: C416
)
]
else:
keys = list(set(self.out_keys + list(tensordict.keys(True, True))))
return tensordict.update(result, keys_to_update=keys)
return result

Expand Down

0 comments on commit 0b1f8b3

Please sign in to comment.