diff --git a/tensordict/nn/functional_modules.py b/tensordict/nn/functional_modules.py index 57d64ad78..549b4d12a 100644 --- a/tensordict/nn/functional_modules.py +++ b/tensordict/nn/functional_modules.py @@ -119,8 +119,6 @@ def set_tensor_dict( # noqa: F811 tree_unflatten, ) -_has_functorch = True - class _exclude_td_from_pytree: def __init__(self): @@ -144,8 +142,7 @@ def unset(self): self.__exit__(None, None, None) -# Monkey-patch functorch, mainly for cases where a "isinstance(obj, Tensor) is invoked -if _has_functorch: +if not strtobool(os.getenv("PYTORCH_TENSORDICT_IMPORT_VMAP", "False")): # Monkey-patches def _process_batched_inputs( diff --git a/test/test_nn.py b/test/test_nn.py index 580f115b0..630b8d3d2 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -2129,6 +2129,7 @@ def test_init(self): p0, p1 ), f"Ensemble params were not initialized correctly {p0}, {p1}" + @pytest.mark.skipif(PYTORCH_TEST_FBCODE, reason="vmap now working in fbcode") @pytest.mark.parametrize( "net", [ @@ -2154,6 +2155,7 @@ def test_siso_forward(self, net): outs = out["dork"].unbind(0) assert not torch.allclose(outs[0], outs[1]), "Outputs should be different" + @pytest.mark.skipif(PYTORCH_TEST_FBCODE, reason="vmap now working in fbcode") @pytest.mark.parametrize( "net", [ diff --git a/test/test_tensordict.py b/test/test_tensordict.py index f613e6a13..812a69f7d 100644 --- a/test/test_tensordict.py +++ b/test/test_tensordict.py @@ -99,6 +99,7 @@ _has_onnx = importlib.util.find_spec("onnxruntime", None) is not None _v2_5 = TORCH_VERSION >= version.parse("2.5.0") +PYTORCH_TEST_FBCODE = os.getenv("PYTORCH_TEST_FBCODE") _IS_OSX = platform.system() == "Darwin" _IS_WINDOWS = sys.platform == "win32" @@ -1151,6 +1152,7 @@ def remover(module, *args, **kwargs): sd = net.state_dict() assert_allclose_td(params_sd.flatten_keys("."), TensorDict(sd, [])) + @pytest.mark.skipif(PYTORCH_TEST_FBCODE, reason="vmap now working in fbcode") @pytest.mark.parametrize("as_module", [False, True]) @pytest.mark.parametrize("lazy_stack", [False, True]) def test_from_modules(self, as_module, lazy_stack): @@ -1196,6 +1198,7 @@ def get_leaf(leaf): assert p.grad is None assert all(param.grad is not None for param in params.values(True, True)) + @pytest.mark.skipif(PYTORCH_TEST_FBCODE, reason="vmap now working in fbcode") @pytest.mark.parametrize("as_module", [False, True]) def test_from_modules_expand(self, as_module): empty_module = nn.Sequential( @@ -1229,6 +1232,7 @@ def exec_module(params, x): if isinstance(param, nn.Parameter) ) + @pytest.mark.skipif(PYTORCH_TEST_FBCODE, reason="vmap now working in fbcode") @pytest.mark.parametrize("as_module", [False, True]) @pytest.mark.parametrize("lazy_stack", [False, True]) @pytest.mark.parametrize("device", get_available_devices()) @@ -3397,6 +3401,7 @@ def test_clamp_max_default(self): TestTensorDictsBase.TYPES_DEVICES, ) class TestTensorDicts(TestTensorDictsBase): + @pytest.mark.skipif(PYTORCH_TEST_FBCODE, reason="vmap now working in fbcode") @pytest.mark.parametrize("nested", [False, True]) def test_add_batch_dim_cache(self, td_name, device, nested): td = getattr(self, td_name)(device)