TorchToTosa: aten.embedding: Allow indices with any rank (#2327)
It is fine not to check the rank of the indices: the conversion flattens the index tensor to shape (1, numElements) before applying tosa::gather, and then reshapes the output tensor to the output shape of the aten.embedding op.
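As a sanity check of that reasoning, here is a minimal PyTorch sketch of the flatten–gather–reshape strategy the conversion relies on, expressed in plain tensor ops rather than TOSA; the helper name embedding_via_flat_gather is hypothetical and not part of the codebase:

import torch

def embedding_via_flat_gather(weight, indices):
    # Mirror the TorchToTosa lowering: flatten the indices to a single
    # row of shape (1, numElements) before gathering.
    flat = indices.reshape(1, -1)
    # Gather one row of the embedding table per flattened index.
    gathered = weight[flat[0]]
    # Reshape to the output shape of aten.embedding:
    # indices.shape + (embedding_dim,).
    return gathered.reshape(*indices.shape, weight.shape[1])

weight = torch.randn(10, 4)  # embedding table: 10 entries of dim 4
for shape in [(3,), (2, 3), (2, 2, 2)]:  # indices of any rank
    idx = torch.randint(0, 10, shape)
    assert torch.equal(embedding_via_flat_gather(weight, idx),
                       torch.nn.functional.embedding(idx, weight))

Because the flatten and the final reshape are shape-agnostic, no rank restriction on the indices is needed, which is exactly why the rank-2 check below can be dropped.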
mgehre-amd authored Jul 21, 2023
1 parent 1e468e8 commit 3ca35b4
Showing 2 changed files with 1 addition and 3 deletions.
1 change: 1 addition & 0 deletions e2e_testing/xfail_sets.py
@@ -1017,6 +1017,7 @@
     "NumpyTRankNStaticModule_basic",
     "NumpyTRankNDynamicModule_basic",
     "EmbeddingModuleI32Static_basic",
+    "EmbeddingModule1DIndices_basic",
     "TModuleRank2_basic",
     "TransposeIntModule_basic",
     "TransposeIntNegDimsModule_basic",
3 changes: 0 additions & 3 deletions lib/Conversion/TorchToTosa/TorchToTosa.cpp
@@ -2961,9 +2961,6 @@ LogicalResult ConvertAtenOp<AtenEmbeddingOp>::matchAndRewrite(
     return rewriter.notifyMatchFailure(
         op, "Indices must be of integer tensor type");
 
-  if (indicesType.getRank() != 2)
-    return rewriter.notifyMatchFailure(op, "indices must be of rank 2");
-
   auto weightType = weight.getType().cast<RankedTensorType>();
   if (weightType.getRank() != 2)
     return op.emitError("weight must be of rank 2");
