Skip to content

Commit

Permalink
Fix broken two tower test
Browse files Browse the repository at this point in the history
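Summary: Fix test_two_tower_retrieval by adding a `dtype` argument (default `torch.float32`) to the modules in torchrec/modules/mlp.py and to TwoTowerRetrieval, running the example with FP16 embedding tables and float16 quantized output, casting the query embedding to float32 before the faiss search, and asserting in the test that at least two CUDA devices are available.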

Differential Revision: D49957751
  • Loading branch information
hlhtsang authored and facebook-github-bot committed Oct 5, 2023
1 parent 369526d commit ffc0caa
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 6 deletions.
12 changes: 10 additions & 2 deletions examples/retrieval/modules/two_tower.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,7 @@ def __init__(
layer_sizes: List[int],
k: int,
device: Optional[torch.device] = None,
dtype: torch.dtype = torch.float32,
) -> None:
super().__init__()
self.embedding_dim: int = query_ebc.embedding_bag_configs()[0].embedding_dim
Expand All @@ -186,10 +187,16 @@ def __init__(
self.query_ebc = query_ebc
self.candidate_ebc = candidate_ebc
self.query_proj = MLP(
in_size=self.embedding_dim, layer_sizes=layer_sizes, device=device
in_size=self.embedding_dim,
layer_sizes=layer_sizes,
device=device,
dtype=dtype,
)
self.candidate_proj = MLP(
in_size=self.embedding_dim, layer_sizes=layer_sizes, device=device
in_size=self.embedding_dim,
layer_sizes=layer_sizes,
device=device,
dtype=dtype,
)
self.faiss_index: Union[faiss.GpuIndexIVFPQ, faiss.IndexIVFPQ] = faiss_index
self.k = k
Expand All @@ -212,6 +219,7 @@ def forward(self, query_kjt: KeyedJaggedTensor) -> torch.Tensor:
candidates = torch.empty(
(batch_size, self.k), device=self.device, dtype=torch.int64
)
query_embedding = query_embedding.to(torch.float32) # required by faiss
self.faiss_index.search(query_embedding, self.k, distances, candidates)

# candidate lookup
Expand Down
1 change: 1 addition & 0 deletions examples/retrieval/tests/test_two_tower_retrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ class InferTest(unittest.TestCase):
"this test requires a GPU",
)
def test_infer_function(self) -> None:
assert torch.cuda.device_count() >= 2
infer(
embedding_dim=16,
layer_sizes=[16],
Expand Down
12 changes: 9 additions & 3 deletions examples/retrieval/two_tower_retrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from torchrec.distributed.model_parallel import DistributedModelParallel
from torchrec.distributed.planner.types import ParameterConstraints
from torchrec.distributed.types import ShardingEnv, ShardingType
from torchrec.modules.embedding_configs import EmbeddingBagConfig
from torchrec.modules.embedding_configs import DataType, EmbeddingBagConfig
from torchrec.modules.embedding_modules import EmbeddingBagCollection
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

Expand Down Expand Up @@ -116,6 +116,7 @@ def infer(
embedding_dim=embedding_dim,
num_embeddings=num_embeddings,
feature_names=[feature_name],
data_type=DataType.FP16,
)
ebcs.append(
EmbeddingBagCollection(
Expand Down Expand Up @@ -156,7 +157,9 @@ def infer(
index.train(embeddings)
index.add(embeddings)

retrieval_model = TwoTowerRetrieval(index, ebcs[0], ebcs[1], layer_sizes, k, device)
retrieval_model = TwoTowerRetrieval(
index, ebcs[0], ebcs[1], layer_sizes, k, device, dtype=torch.float16
)

constraints = {}
for feature_name in two_tower_column_names:
Expand All @@ -166,7 +169,10 @@ def infer(
)

quant_model = trec_infer.modules.quantize_embeddings(
retrieval_model, dtype=torch.qint8, inplace=True
retrieval_model,
dtype=torch.qint8,
inplace=True,
output_dtype=torch.float16,
)

dmp = DistributedModelParallel(
Expand Down
9 changes: 8 additions & 1 deletion torchrec/modules/mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,13 +50,18 @@ def __init__(
Callable[[torch.Tensor], torch.Tensor],
] = torch.relu,
device: Optional[torch.device] = None,
dtype: torch.dtype = torch.float32,
) -> None:
super().__init__()
torch._C._log_api_usage_once(f"torchrec.modules.{self.__class__.__name__}")
self._out_size = out_size
self._in_size = in_size
self._linear: nn.Linear = nn.Linear(
self._in_size, self._out_size, bias=bias, device=device
self._in_size,
self._out_size,
bias=bias,
device=device,
dtype=dtype,
)
self._activation_fn: Callable[[torch.Tensor], torch.Tensor] = activation

Expand Down Expand Up @@ -120,6 +125,7 @@ def __init__(
Callable[[torch.Tensor], torch.Tensor],
] = torch.relu,
device: Optional[torch.device] = None,
dtype: torch.dtype = torch.float32,
) -> None:
super().__init__()

Expand All @@ -137,6 +143,7 @@ def __init__(
bias=bias,
activation=extract_module_or_tensor_callable(activation),
device=device,
dtype=dtype,
)
for i in range(len(layer_sizes))
]
Expand Down

0 comments on commit ffc0caa

Please sign in to comment.