Skip to content

Commit

Permalink
makes embedding dim mismatch cases easier to debug (pytorch#1416)
Browse files Browse the repository at this point in the history
Summary: Pull Request resolved: pytorch#1416

Reviewed By: YLGH

Differential Revision: D49806973

fbshipit-source-id: 8618ac3f60cd9899abdafb6bc569c783dc225a41
  • Loading branch information
Jiaqi Zhai authored and facebook-github-bot committed Oct 3, 2023
1 parent 9b3e0c0 commit 420a691
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 0 deletions.
2 changes: 2 additions & 0 deletions torchrec/modules/embedding_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,8 @@ def __init__( # noqa C901
if self._embedding_dim != config.embedding_dim:
raise ValueError(
"All tables in a EmbeddingCollection are required to have same embedding dimension."
+ f" Violating case: {config.name}'s embedding_dim {config.embedding_dim} !="
+ f" {self._embedding_dim}"
)
dtype = (
torch.float32 if config.data_type == DataType.FP32 else torch.float16
Expand Down
2 changes: 2 additions & 0 deletions torchrec/modules/fused_embedding_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -650,6 +650,8 @@ def __init__(
elif self._embedding_dim != table.embedding_dim:
raise ValueError(
"All tables in a EmbeddingCollection are required to have same embedding dimension."
+ f" Violating case: {table}'s embedding_dim {table.embedding_dim} !="
+ f" {self._embedding_dim}"
)
for feature in table.feature_names:
if feature in seen_features:
Expand Down
2 changes: 2 additions & 0 deletions torchrec/quant/embedding_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -576,6 +576,8 @@ def __init__( # noqa C901
if self._embedding_dim != config.embedding_dim:
raise ValueError(
"All tables in a EmbeddingCollection are required to have same embedding dimension."
+ f" Violating case: {config.name}'s embedding_dim {config.embedding_dim} !="
+ f" {self._embedding_dim}"
)
weight_lists: Optional[
List[Tuple[torch.Tensor, Optional[torch.Tensor]]]
Expand Down

0 comments on commit 420a691

Please sign in to comment.