diff --git a/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp b/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp index 4aad004df0f7..e154f5cb92ef 100644 --- a/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp +++ b/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp @@ -1726,7 +1726,7 @@ class ConvertAtenScaledDotProductAttentionOp } // Broadcast the batch dimensions of the mask: - if (mask) { + if (!isa(mask.getType())) { auto maskTy = cast(mask.getType()); int64_t rank = maskTy.getRank(); bool needsBroadcast = false;