From ffc0caaeb200133284fd5666873a49b21e146d13 Mon Sep 17 00:00:00 2001
From: Henry Tsang
Date: Thu, 5 Oct 2023 13:51:07 -0700
Subject: [PATCH] Fix broken two tower test

Summary:
Fixing test_two_tower_retrieval: thread a dtype argument through MLP and
TwoTowerRetrieval so the dense projections match the FP16 embedding tables,
cast the query embeddings back to float32 before the faiss search, and
assert that at least two GPUs are available in the test.

Differential Revision: D49957751
---
 examples/retrieval/modules/two_tower.py              | 12 ++++++++++--
 examples/retrieval/tests/test_two_tower_retrieval.py |  1 +
 examples/retrieval/two_tower_retrieval.py            | 12 +++++++++---
 torchrec/modules/mlp.py                              |  9 ++++++++-
 4 files changed, 28 insertions(+), 6 deletions(-)

diff --git a/examples/retrieval/modules/two_tower.py b/examples/retrieval/modules/two_tower.py
index cb1ac1954..224bcbfc7 100644
--- a/examples/retrieval/modules/two_tower.py
+++ b/examples/retrieval/modules/two_tower.py
@@ -174,6 +174,7 @@ def __init__(
         layer_sizes: List[int],
         k: int,
         device: Optional[torch.device] = None,
+        dtype: torch.dtype = torch.float32,
     ) -> None:
         super().__init__()
         self.embedding_dim: int = query_ebc.embedding_bag_configs()[0].embedding_dim
@@ -186,10 +187,16 @@ def __init__(
         self.query_ebc = query_ebc
         self.candidate_ebc = candidate_ebc
         self.query_proj = MLP(
-            in_size=self.embedding_dim, layer_sizes=layer_sizes, device=device
+            in_size=self.embedding_dim,
+            layer_sizes=layer_sizes,
+            device=device,
+            dtype=dtype,
         )
         self.candidate_proj = MLP(
-            in_size=self.embedding_dim, layer_sizes=layer_sizes, device=device
+            in_size=self.embedding_dim,
+            layer_sizes=layer_sizes,
+            device=device,
+            dtype=dtype,
         )
         self.faiss_index: Union[faiss.GpuIndexIVFPQ, faiss.IndexIVFPQ] = faiss_index
         self.k = k
@@ -212,6 +219,7 @@ def forward(self, query_kjt: KeyedJaggedTensor) -> torch.Tensor:
         candidates = torch.empty(
             (batch_size, self.k), device=self.device, dtype=torch.int64
         )
+        query_embedding = query_embedding.to(torch.float32)  # required by faiss
         self.faiss_index.search(query_embedding, self.k, distances, candidates)

         # candidate lookup
diff --git a/examples/retrieval/tests/test_two_tower_retrieval.py b/examples/retrieval/tests/test_two_tower_retrieval.py
index eef1ef455..79fedc009 100644
--- a/examples/retrieval/tests/test_two_tower_retrieval.py
+++ b/examples/retrieval/tests/test_two_tower_retrieval.py
@@ -23,6 +23,7 @@ class InferTest(unittest.TestCase):
         "this test requires a GPU",
     )
     def test_infer_function(self) -> None:
+        assert torch.cuda.device_count() >= 2
         infer(
             embedding_dim=16,
             layer_sizes=[16],
diff --git a/examples/retrieval/two_tower_retrieval.py b/examples/retrieval/two_tower_retrieval.py
index b1b4ccb49..4b1ff3a4f 100644
--- a/examples/retrieval/two_tower_retrieval.py
+++ b/examples/retrieval/two_tower_retrieval.py
@@ -18,7 +18,7 @@
 from torchrec.distributed.model_parallel import DistributedModelParallel
 from torchrec.distributed.planner.types import ParameterConstraints
 from torchrec.distributed.types import ShardingEnv, ShardingType
-from torchrec.modules.embedding_configs import EmbeddingBagConfig
+from torchrec.modules.embedding_configs import DataType, EmbeddingBagConfig
 from torchrec.modules.embedding_modules import EmbeddingBagCollection
 from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

@@ -116,6 +116,7 @@ def infer(
             embedding_dim=embedding_dim,
             num_embeddings=num_embeddings,
             feature_names=[feature_name],
+            data_type=DataType.FP16,
         )
         ebcs.append(
             EmbeddingBagCollection(
@@ -156,7 +157,9 @@ def infer(
     index.train(embeddings)
     index.add(embeddings)

-    retrieval_model = TwoTowerRetrieval(index, ebcs[0], ebcs[1], layer_sizes, k, device)
+    retrieval_model = TwoTowerRetrieval(
+        index, ebcs[0], ebcs[1], layer_sizes, k, device, dtype=torch.float16
+    )

     constraints = {}
     for feature_name in two_tower_column_names:
@@ -166,7 +169,10 @@ def infer(
         )

     quant_model = trec_infer.modules.quantize_embeddings(
-        retrieval_model, dtype=torch.qint8, inplace=True
+        retrieval_model,
+        dtype=torch.qint8,
+        inplace=True,
+        output_dtype=torch.float16,
     )

     dmp = DistributedModelParallel(
diff --git a/torchrec/modules/mlp.py b/torchrec/modules/mlp.py
index c369b24c3..41685b341 100644
--- a/torchrec/modules/mlp.py
+++ b/torchrec/modules/mlp.py
@@ -50,13 +50,18 @@ def __init__(
             Callable[[torch.Tensor], torch.Tensor],
         ] = torch.relu,
         device: Optional[torch.device] = None,
+        dtype: torch.dtype = torch.float32,
     ) -> None:
         super().__init__()
         torch._C._log_api_usage_once(f"torchrec.modules.{self.__class__.__name__}")
         self._out_size = out_size
         self._in_size = in_size
         self._linear: nn.Linear = nn.Linear(
-            self._in_size, self._out_size, bias=bias, device=device
+            self._in_size,
+            self._out_size,
+            bias=bias,
+            device=device,
+            dtype=dtype,
         )
         self._activation_fn: Callable[[torch.Tensor], torch.Tensor] = activation

@@ -120,6 +125,7 @@ def __init__(
             Callable[[torch.Tensor], torch.Tensor],
         ] = torch.relu,
         device: Optional[torch.device] = None,
+        dtype: torch.dtype = torch.float32,
     ) -> None:
         super().__init__()

@@ -137,6 +143,7 @@ def __init__(
                 bias=bias,
                 activation=extract_module_or_tensor_callable(activation),
                 device=device,
+                dtype=dtype,
             )
             for i in range(len(layer_sizes))
         ]
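For reference, a minimal sketch of how the new dtype argument can be exercised.
The import path and keyword names are taken from the diff above; the layer
sizes are illustrative assumptions, not values from the patch:

    import torch
    from torchrec.modules.mlp import MLP

    # Build an MLP whose nn.Linear layers are created directly in float16 via
    # the dtype argument added in this patch. Sizes are made-up example values.
    mlp = MLP(in_size=16, layer_sizes=[16, 8], dtype=torch.float16)

    # Every linear parameter is float16 without a separate .half() call.
    assert all(p.dtype == torch.float16 for p in mlp.parameters())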