fix forward for D51601041
Summary: as title

Differential Revision: D51951478
Joe Wang authored and facebook-github-bot committed Dec 7, 2023
1 parent a1b61b3 commit 7ad4c07
Showing 1 changed file with 26 additions and 5 deletions.
31 changes: 26 additions & 5 deletions torchrec/modules/tests/test_mc_embedding_modules.py
@@ -6,11 +6,9 @@
 # LICENSE file in the root directory of this source tree.
 
 import unittest
-from typing import cast
+from typing import cast, List, Optional
 
 import torch
-from dper_lib.fx_tracing import DperTracer
-from dper_lib.fx_tracing.graph_module import DperGraphModule
 from torchrec.modules.embedding_configs import EmbeddingBagConfig
 from torchrec.modules.embedding_modules import EmbeddingBagCollection
 from torchrec.modules.mc_embedding_modules import ManagedCollisionEmbeddingBagCollection
@@ -23,6 +21,22 @@
 from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
 
 
+class Tracer(torch.fx.Tracer):
+    _leaf_module_names: List[str]
+
+    def __init__(self, leaf_module_names: Optional[List[str]] = None) -> None:
+        super().__init__()
+        self._leaf_module_names = leaf_module_names or []
+
+    def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool:
+        if (
+            type(m).__name__ in self._leaf_module_names
+            or module_qualified_name in self._leaf_module_names
+        ):
+            return True
+        return super().is_leaf_module(m, module_qualified_name)
+
+
 class MCHManagedCollisionEmbeddingBagCollectionTest(unittest.TestCase):
     def test_zch_ebc_train(self) -> None:
         device = torch.device("cpu")
@@ -288,8 +302,15 @@ def test_mc_collection_traceable(self) -> None:
             # pyre-ignore[6]
             embedding_configs=embedding_configs,
         )
-        graph: torch.fx.Graph = DperTracer().trace(mcc)
-        gm: torch.fx.GraphModule = DperGraphModule(mcc, graph)
+
+        mcc = ManagedCollisionCollection(
+            managed_collision_modules=mc_modules,
+            # pyre-ignore[6]
+            embedding_configs=embedding_configs,
+        )
+        trec_tracer = Tracer(["ComputeJTDictToKJT"])
+        graph: torch.fx.Graph = trec_tracer.trace(mcc)
+        gm: torch.fx.GraphModule = torch.fx.GraphModule(mcc, graph)
         gm.print_readable()
 
         # TODO: since this is unsharded module, also check torch.jit.script
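
For readers unfamiliar with the pattern, below is a minimal, self-contained sketch (not part of the commit) of how a torch.fx Tracer with configurable leaf modules behaves. The LeafAwareTracer, Inner, and Outer names are made up for illustration; the fx calls used (trace, GraphModule, print_readable) are standard torch.fx APIs.

# Hypothetical example mirroring the Tracer added in this diff: modules whose
# class name (or qualified name) is listed are kept as single call_module
# nodes instead of being traced through.
from typing import List, Optional

import torch


class LeafAwareTracer(torch.fx.Tracer):
    def __init__(self, leaf_module_names: Optional[List[str]] = None) -> None:
        super().__init__()
        self._leaf_module_names = leaf_module_names or []

    def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool:
        if (
            type(m).__name__ in self._leaf_module_names
            or module_qualified_name in self._leaf_module_names
        ):
            return True
        return super().is_leaf_module(m, module_qualified_name)


class Inner(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(x) * 2


class Outer(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.inner = Inner()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.inner(x) + 1


model = Outer()
# Without the leaf list, Inner's forward would be inlined into the graph;
# with it, the graph keeps a single call_module node for "inner".
graph: torch.fx.Graph = LeafAwareTracer(["Inner"]).trace(model)
gm = torch.fx.GraphModule(model, graph)
gm.print_readable()

Keeping a submodule as a leaf is useful when its forward is not symbolically traceable (for example, code that manipulates jagged-tensor structures), which appears to be why the test marks ComputeJTDictToKJT as a leaf before tracing the ManagedCollisionCollection.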