Skip to content

Commit

Permalink
[Mosaic:TPU] Enable broadcast from 1-D vectors
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 700592669
tlongeri authored and Google-ML-Automation committed Nov 27, 2024

Verified

This commit was signed with the committer’s verified signature.
maleck13 Craig Brookes
1 parent 47d1960 commit 7a2070e
Showing 1 changed file with 1 addition and 3 deletions.
4 changes: 1 addition & 3 deletions jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc
Original file line number Diff line number Diff line change
@@ -1110,12 +1110,10 @@ class VectorLayoutInferer {
return success();
}
if (auto src_ty = dyn_cast<VectorType>(some_src_ty)) {
TPU_CHECK_OP(src_ty.getRank() >= 2, "source rank below 2D unsupported");
TPU_CHECK_OP(res_ty.getRank() >= 2, "result rank below 2D unsupported");
auto some_layout = getLayout(op.getSource());
TPU_CHECK_OP(some_layout.has_value(), "missing vector layout");
auto &layout = *some_layout;
if (layout.implicit_dim() != ImplicitDim::kNone) {
if (layout.implicit_dim() != ImplicitDim::kNone && src_ty.getRank() > 1) {
VectorLayout layout_2d(layout.bitwidth(), layout.offsets(),
layout.tiling(), ImplicitDim::kNone);
if (layout_2d.equivalentTo(layout, src_ty.getShape(), target_shape_)) {

0 comments on commit 7a2070e

Please sign in to comment.