Fix dtype mismatch on meta (pytorch#1689)
Summary:
Pull Request resolved: pytorch#1689

Fix the issue described in D52591217.

Reviewed By: dstaay-fb, silverlakeli

Differential Revision: D53232739

fbshipit-source-id: 9e290a12e6b1dad283f687c7ae6c7f6db43fabc9
ge0405 authored and facebook-github-bot committed Feb 9, 2024
1 parent d6b3da6 commit 47441cf
Showing 1 changed file with 12 additions and 1 deletion.
torchrec/modules/embedding_modules.py
@@ -163,6 +163,7 @@ def __init__(
         self._device: torch.device = (
             device if device is not None else torch.device("cpu")
         )
+        self._dtypes: List[int] = []
 
         table_names = set()
         for embedding_config in tables:
@@ -182,6 +183,7 @@ def __init__(
                 include_last_offset=True,
                 dtype=dtype,
             )
+            self._dtypes.append(embedding_config.data_type.value)
 
             if not embedding_config.feature_names:
                 embedding_config.feature_names = [embedding_config.name]
@@ -217,10 +219,19 @@ def forward(self, features: KeyedJaggedTensor) -> KeyedTensor:
         for i, embedding_bag in enumerate(self.embedding_bags.values()):
             for feature_name in self._feature_names[i]:
                 f = feature_dict[feature_name]
+                per_sample_weights: Optional[torch.Tensor] = None
+                if self._is_weighted:
+                    per_sample_weights = (
+                        f.weights().half()
+                        if self._dtypes[i] == DataType.FP16.value
+                        else f.weights()
+                    )
                 res = embedding_bag(
                     input=f.values(),
                     offsets=f.offsets(),
-                    per_sample_weights=f.weights() if self._is_weighted else None,
+                    per_sample_weights=per_sample_weights
+                    if self._is_weighted
+                    else None,
                 ).float()
                 pooled_embeddings.append(res)
         return KeyedTensor(
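
The change records each table's DataType at construction and, in forward, casts per-sample weights to half precision when the table is FP16, so their dtype matches the embedding weights even on a meta device, where the mismatch otherwise surfaces as an error. Below is a minimal sketch of the pattern using a plain nn.EmbeddingBag rather than the full EmbeddingBagCollection; the table size, values, and offsets are illustrative assumptions, not taken from the PR:

    import torch

    # Illustrative sketch: an FP16 embedding table paired with default-FP32
    # per-sample weights, the combination this commit guards against on meta.
    device = torch.device("meta")
    bag = torch.nn.EmbeddingBag(
        num_embeddings=10,         # assumed size
        embedding_dim=4,           # assumed size
        mode="sum",                # per_sample_weights requires "sum" pooling
        include_last_offset=True,
        device=device,
        dtype=torch.float16,       # FP16 table, as with DataType.FP16
    )
    values = torch.zeros(3, dtype=torch.int64, device=device)
    offsets = torch.tensor([0, 2, 3], device=device)  # two bags, last offset included
    weights = torch.zeros(3, device=device)           # FP32 by default

    # bag(values, offsets, per_sample_weights=weights)  # dtype mismatch on meta
    out = bag(values, offsets, per_sample_weights=weights.half())  # cast, as the fix does
    print(out.shape, out.dtype)  # torch.Size([2, 4]) torch.float16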
