[FBCode] Deactivate vmap monkey-patching in FBCode
ghstack-source-id: 4f1ebfd0fd4ff5b7378c8692a064406e72fc68c0
Pull Request resolved: #1135
vmoens committed Dec 9, 2024
1 parent 50ea33b commit d91fd1c
Showing 3 changed files with 8 additions and 4 deletions.
5 changes: 1 addition & 4 deletions tensordict/nn/functional_modules.py
@@ -119,8 +119,6 @@ def set_tensor_dict( # noqa: F811
     tree_unflatten,
 )
 
-_has_functorch = True
-
 
 class _exclude_td_from_pytree:
     def __init__(self):
@@ -144,8 +142,7 @@ def unset(self):
         self.__exit__(None, None, None)
 
 
-# Monkey-patch functorch, mainly for cases where a "isinstance(obj, Tensor) is invoked
-if _has_functorch:
+if not strtobool(os.getenv("PYTORCH_TENSORDICT_IMPORT_VMAP", "False")):
     # Monkey-patches
 
     def _process_batched_inputs(
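
The gate above replaces the unconditional _has_functorch flag with an environment-variable check. A minimal sketch of the pattern, with a stand-in strtobool helper (tensordict ships its own; only the variable name and default come from the diff):

    import os

    def strtobool(val: str) -> bool:
        # Stand-in for the strtobool helper referenced in the diff.
        return val.strip().lower() in ("y", "yes", "t", "true", "on", "1")

    # Unset or "False" (the default) keeps the previous behavior and installs
    # the functorch monkey-patches; a truthy value (as presumably set for
    # FBCode builds) skips them and leaves torch.vmap untouched.
    if not strtobool(os.getenv("PYTORCH_TENSORDICT_IMPORT_VMAP", "False")):
        pass  # monkey-patches such as _process_batched_inputs are installed here
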
2 changes: 2 additions & 0 deletions test/test_nn.py
@@ -2129,6 +2129,7 @@ def test_init(self):
             p0, p1
         ), f"Ensemble params were not initialized correctly {p0}, {p1}"
 
+    @pytest.mark.skipif(PYTORCH_TEST_FBCODE, reason="vmap not working in fbcode")
     @pytest.mark.parametrize(
         "net",
         [
@@ -2154,6 +2155,7 @@ def test_siso_forward(self, net):
         outs = out["dork"].unbind(0)
         assert not torch.allclose(outs[0], outs[1]), "Outputs should be different"
 
+    @pytest.mark.skipif(PYTORCH_TEST_FBCODE, reason="vmap not working in fbcode")
     @pytest.mark.parametrize(
         "net",
         [
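
Both ensemble tests go through tensordict's patched vmap, hence the new FBCode skip. A sketch of the guard in isolation (test_needs_vmap is hypothetical; in test_nn.py the PYTORCH_TEST_FBCODE flag is assumed to be defined the same way as in the test_tensordict.py hunk below):

    import os

    import pytest

    # os.getenv returns None when the variable is unset, so the condition is
    # falsy outside FBCode; any non-empty value activates the skip.
    PYTORCH_TEST_FBCODE = os.getenv("PYTORCH_TEST_FBCODE")

    @pytest.mark.skipif(PYTORCH_TEST_FBCODE, reason="vmap not working in fbcode")
    def test_needs_vmap():
        pass  # body would rely on the monkey-patched vmap
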
5 changes: 5 additions & 0 deletions test/test_tensordict.py
@@ -99,6 +99,7 @@
 _has_onnx = importlib.util.find_spec("onnxruntime", None) is not None
 
 _v2_5 = TORCH_VERSION >= version.parse("2.5.0")
+PYTORCH_TEST_FBCODE = os.getenv("PYTORCH_TEST_FBCODE")
 
 _IS_OSX = platform.system() == "Darwin"
 _IS_WINDOWS = sys.platform == "win32"
@@ -1151,6 +1152,7 @@ def remover(module, *args, **kwargs):
         sd = net.state_dict()
         assert_allclose_td(params_sd.flatten_keys("."), TensorDict(sd, []))
 
+    @pytest.mark.skipif(PYTORCH_TEST_FBCODE, reason="vmap not working in fbcode")
     @pytest.mark.parametrize("as_module", [False, True])
     @pytest.mark.parametrize("lazy_stack", [False, True])
     def test_from_modules(self, as_module, lazy_stack):
@@ -1196,6 +1198,7 @@ def get_leaf(leaf):
         assert p.grad is None
         assert all(param.grad is not None for param in params.values(True, True))
 
+    @pytest.mark.skipif(PYTORCH_TEST_FBCODE, reason="vmap not working in fbcode")
     @pytest.mark.parametrize("as_module", [False, True])
     def test_from_modules_expand(self, as_module):
         empty_module = nn.Sequential(
@@ -1229,6 +1232,7 @@ def exec_module(params, x):
             if isinstance(param, nn.Parameter)
         )
 
+    @pytest.mark.skipif(PYTORCH_TEST_FBCODE, reason="vmap not working in fbcode")
     @pytest.mark.parametrize("as_module", [False, True])
     @pytest.mark.parametrize("lazy_stack", [False, True])
     @pytest.mark.parametrize("device", get_available_devices())
@@ -3397,6 +3401,7 @@ def test_clamp_max_default(self):
     TestTensorDictsBase.TYPES_DEVICES,
 )
 class TestTensorDicts(TestTensorDictsBase):
+    @pytest.mark.skipif(PYTORCH_TEST_FBCODE, reason="vmap not working in fbcode")
     @pytest.mark.parametrize("nested", [False, True])
     def test_add_batch_dim_cache(self, td_name, device, nested):
         td = getattr(self, td_name)(device)
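
Every newly skipped test pushes a TensorDict through torch.vmap, which is exactly the path the monkey-patching enables. A sketch of that pattern, modeled on the exec_module helper visible in the diff (module sizes and shapes are illustrative):

    import torch
    from torch import nn

    from tensordict import TensorDict

    # Stack the parameters of several identical modules along a leading dim.
    modules = [nn.Linear(3, 4) for _ in range(5)]
    params = TensorDict.from_modules(*modules)

    def exec_module(params, x):
        # Temporarily load one slice of the stacked parameters into a module.
        with params.to_module(modules[0]):
            return modules[0](x)

    # vmap over the parameter stack; this requires TensorDict to be a valid
    # vmap input, which is what the (now FBCode-disabled) patching provides.
    x = torch.randn(3)
    ys = torch.vmap(exec_module, (0, None))(params, x)
    assert ys.shape == (5, 4)
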
