From 7ad4c070b993d2c6c8cc1f4ce957d4cd1fc81a09 Mon Sep 17 00:00:00 2001 From: Joe Wang Date: Thu, 7 Dec 2023 11:27:16 -0800 Subject: [PATCH] fix foward for D51601041 Summary: as title Differential Revision: D51951478 --- .../tests/test_mc_embedding_modules.py | 31 ++++++++++++++++--- 1 file changed, 26 insertions(+), 5 deletions(-) diff --git a/torchrec/modules/tests/test_mc_embedding_modules.py b/torchrec/modules/tests/test_mc_embedding_modules.py index 56369b4e4..a566b7649 100644 --- a/torchrec/modules/tests/test_mc_embedding_modules.py +++ b/torchrec/modules/tests/test_mc_embedding_modules.py @@ -6,11 +6,9 @@ # LICENSE file in the root directory of this source tree. import unittest -from typing import cast +from typing import cast, List, Optional import torch -from dper_lib.fx_tracing import DperTracer -from dper_lib.fx_tracing.graph_module import DperGraphModule from torchrec.modules.embedding_configs import EmbeddingBagConfig from torchrec.modules.embedding_modules import EmbeddingBagCollection from torchrec.modules.mc_embedding_modules import ManagedCollisionEmbeddingBagCollection @@ -23,6 +21,22 @@ from torchrec.sparse.jagged_tensor import KeyedJaggedTensor +class Tracer(torch.fx.Tracer): + _leaf_module_names: List[str] + + def __init__(self, leaf_module_names: Optional[List[str]] = None) -> None: + super().__init__() + self._leaf_module_names = leaf_module_names or [] + + def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool: + if ( + type(m).__name__ in self._leaf_module_names + or module_qualified_name in self._leaf_module_names + ): + return True + return super().is_leaf_module(m, module_qualified_name) + + class MCHManagedCollisionEmbeddingBagCollectionTest(unittest.TestCase): def test_zch_ebc_train(self) -> None: device = torch.device("cpu") @@ -288,8 +302,15 @@ def test_mc_collection_traceable(self) -> None: # pyre-ignore[6] embedding_configs=embedding_configs, ) - graph: torch.fx.Graph = DperTracer().trace(mcc) - gm: torch.fx.GraphModule = DperGraphModule(mcc, graph) + + mcc = ManagedCollisionCollection( + managed_collision_modules=mc_modules, + # pyre-ignore[6] + embedding_configs=embedding_configs, + ) + trec_tracer = Tracer(["ComputeJTDictToKJT"]) + graph: torch.fx.Graph = trec_tracer.trace(mcc) + gm: torch.fx.GraphModule = torch.fx.GraphModule(mcc, graph) gm.print_readable() # TODO: since this is unsharded module, also check torch.jit.script