diff --git a/tensordict/nn/probabilistic.py b/tensordict/nn/probabilistic.py index ecb90a73d..8a3fcb855 100644 --- a/tensordict/nn/probabilistic.py +++ b/tensordict/nn/probabilistic.py @@ -1119,7 +1119,18 @@ def forward( else: result = tensordict_exec if self._select_before_return: - return tensordict.update(result, keys_to_update=self.out_keys) + # We must also update any value that has been updated during the course of execution + # from the input data. + if is_compiling(): + keys = [ # noqa: C416 + k + for k in {k for k in self.out_keys}.union( # noqa: C416 + {k for k in tensordict.keys(True, True)} # noqa: C416 + ) + ] + else: + keys = list(set(self.out_keys + list(tensordict.keys(True, True)))) + return tensordict.update(result, keys_to_update=keys) return result diff --git a/tensordict/nn/sequence.py b/tensordict/nn/sequence.py index df9ecd1e9..faa2f60a0 100644 --- a/tensordict/nn/sequence.py +++ b/tensordict/nn/sequence.py @@ -35,6 +35,10 @@ ) FUNCTORCH_ERROR = "functorch not installed. Consider installing functorch to use this functionality." +try: + from torch.compiler import is_compiling +except ImportError: + from torch._dynamo import is_compiling __all__ = ["TensorDictSequential"] @@ -491,7 +495,15 @@ def forward( if self._select_before_return: # We must also update any value that has been updated during the course of execution # from the input data. - keys = list(set(self.out_keys + list(tensordict.keys(True, True)))) + if is_compiling(): + keys = [ # noqa: C416 + k + for k in {k for k in self.out_keys}.union( # noqa: C416 + {k for k in tensordict.keys(True, True)} # noqa: C416 + ) + ] + else: + keys = list(set(self.out_keys + list(tensordict.keys(True, True)))) return tensordict.update(result, keys_to_update=keys) return result