diff --git a/jaxlib/mosaic/dialect/tpu/layout.h b/jaxlib/mosaic/dialect/tpu/layout.h
index 6edad713b17a..2c45be62fa7d 100644
--- a/jaxlib/mosaic/dialect/tpu/layout.h
+++ b/jaxlib/mosaic/dialect/tpu/layout.h
@@ -252,6 +252,17 @@ class VectorLayout {
   int8_t bitwidth() const { return bitwidth_; }
   const LayoutOffsets &offsets() const { return offsets_; }
+  const LayoutOffsets getCanonicalOffsets(
+      const ArrayRef<int64_t> shape,
+      const std::array<int64_t, 2> target_shape) const {
+    // For (1, n) tiling with a single row, 2nd minor replication does not
+    // change anything about the layout - it is equivalent to an offset of 0.
+    // We choose a replicated offset as "canonical".
+    const std::array<int64_t, 2> tiled_ishape = getImplicitTiledDims(shape, 1);
+    return {
+        (tiling_[0] == 1 && tiled_ishape[0] == 1) ? std::nullopt : offsets_[0],
+        offsets_[1]};
+  }
   const std::array<int64_t, 2> &tiling() const { return tiling_; }
   ImplicitDim implicit_dim() const { return implicit_dim_; }
   int packing() const { return 32 / bitwidth_; }
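Note for reviewers: the following is a minimal standalone sketch of the canonicalization rule above (plain C++17, not Mosaic code; `canonicalOffsets` is a hypothetical free-function stand-in for `VectorLayout::getCanonicalOffsets`, with `tiled_ishape` passed in directly instead of calling `getImplicitTiledDims`). For (1, n) tiling of a single implicit row, a 2nd-minor offset of 0 and a replicated 2nd-minor offset describe the same data placement, so both canonicalize to the replicated form.

    #include <array>
    #include <cassert>
    #include <cstdint>
    #include <optional>

    using LayoutOffset = std::optional<int64_t>;
    using LayoutOffsets = std::array<LayoutOffset, 2>;

    // Hypothetical stand-in for VectorLayout::getCanonicalOffsets.
    LayoutOffsets canonicalOffsets(const LayoutOffsets& offsets,
                                   const std::array<int64_t, 2>& tiling,
                                   const std::array<int64_t, 2>& tiled_ishape) {
      // With (1, n) tiling and a single row, every sublane of a vreg holds data
      // from that one row, so 2nd-minor replication is indistinguishable from
      // an offset of 0. Pick the replicated form as canonical.
      return {(tiling[0] == 1 && tiled_ishape[0] == 1) ? std::nullopt : offsets[0],
              offsets[1]};
    }

    int main() {
      const std::array<int64_t, 2> tiling = {1, 128};
      const std::array<int64_t, 2> single_row = {1, 512};
      // Offset 0 and a replicated 2nd minor canonicalize identically...
      assert(canonicalOffsets({0, 0}, tiling, single_row) ==
             canonicalOffsets({std::nullopt, 0}, tiling, single_row));
      // ...but with more than one row the distinction is preserved.
      const std::array<int64_t, 2> two_rows = {2, 512};
      assert(canonicalOffsets({0, 0}, tiling, two_rows)[0].has_value());
      return 0;
    }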
diff --git a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc
index 7ca4204db343..6414b8eadf4d 100644
--- a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc
+++ b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc
@@ -6164,42 +6164,83 @@ FailureOr<std::pair<VectorLayout, xla::Array<Value>>> changeTiling(
   if (src_tiling == dst_tiling) {
     return std::pair(src, std::move(vregs));
   }
+  const LayoutOffsets src_offsets =
+      src.getCanonicalOffsets(vty.getShape(), ctx.target_shape);
+  const std::array<int64_t, 2> tiled_ishape =
+      src.getImplicitTiledDims(vty.getShape(), 1);
   const int packing = src.packing();
   const int8_t bitwidth = src.bitwidth();
-  // Handle retiling from (1, 128) to (8, 128) for 32-bit data with replicating
-  // sublanes.
-  if (try_replicate_rows && packing == 1 &&
-      *(vregs.dimensions().end() - 2) == 1 &&
-      src.tiling() == std::array<int64_t, 2>{1, ctx.target_shape[1]} &&
-      dst_tiling == ctx.target_shape) {
-    DCHECK_EQ(src.offsets()[0].value_or(0), 0);
+  const std::array<int64_t, 2> dst_vreg_slice =
+      VectorLayout::vregSlice(ctx.target_shape, bitwidth, dst_tiling);
+
+  // Fully replicated offsets are handled efficiently elsewhere (in relayout).
+  CHECK(src.offsets()[0].has_value() || src.offsets()[1].has_value());
+
+  // Handle replicating small-to-large retiling for (a) replicated 2nd minor or
+  // (b) 32-bit single-row. This retiling is one-to-many: each src vreg feeds
+  // multiple dst vregs.
+  // TODO(tlongeri): Large-to-small retiling with replicated minor is analogous
+  // to this.
+  if (src_tiling[1] == ctx.target_shape[1] &&
+      dst_tiling[1] == ctx.target_shape[1] &&
+      dst_tiling[0] % src_tiling[0] == 0 &&
+      (!src_offsets[0].has_value() || (packing == 1 && tiled_ishape[0] == 1)) &&
+      // This relayout relies on gathers, which are cheap on newer generations,
+      // so we always use it for them.
+      // TODO(tlongeri): Once we have it, probably also prefer the
+      // small-to-large rotate+blend relayout when replication is not needed; it
+      // is slightly cheaper because some dst vregs are rotated by 0.
+      // TODO(tlongeri): Using store + multiple replicated loads is good on
+      // older gens. It may be worth folding this logic into scratch retiling.
+      (try_replicate_rows || ctx.hardware_generation >= 5)) {
     const LayoutOffset dst_minor_offset =
-        src.offsets()[1] ? LayoutOffset(*src.offsets()[1] % target_shape[1])
-                         : std::nullopt;
+        src.offsets()[1].has_value() ? *src.offsets()[1] % dst_vreg_slice[1]
+                                     : LayoutOffset();
     const VectorLayout dst(bitwidth, {std::nullopt, dst_minor_offset},
                            dst_tiling, src.implicit_dim());
-    xla::Array<Value> retiled(
-        dst.tileArrayImplicitShape(vty.getShape(), target_shape));
-    retiled.Each([&](absl::Span<const int64_t> idx, Value *tile) {
-      SmallVector<int64_t> src_idx(idx.begin(), idx.end());
-      *(src_idx.end() - 2) *= target_shape[0];
-      if (!src.offsets()[1].has_value()) {
-        // With (1, 128) tiling each vreg holds values from a single row. This
-        // means that if the columns are replicated, then the whole vreg is
-        // already replicated.
-        *(src_idx.end() - 1) = 0;
-        *tile = vregs(src_idx);
-      } else {
-        // The column (in units of sublanes) of the sublane we want:
-        const int64_t sublane_column =
-            *(src_idx.end() - 1) + *src.offsets()[1] / target_shape[1];
-        *(src_idx.end() - 1) = sublane_column / target_shape[0];
-        const int64_t src_sl_idx = sublane_column % target_shape[0];
-        *tile =
-            broadcastSublane(builder, vregs(src_idx), src_sl_idx, target_shape);
+    const SmallVector<int64_t> dst_vreg_array_shape =
+        dst.tileArrayImplicitShape(vty.getShape(), target_shape);
+    const int64_t src_tiles_per_vreg = src.tilesPerVreg(ctx.target_shape);
+    const int64_t dst_tiles_per_vreg = dst.tilesPerVreg(ctx.target_shape);
+    const int64_t src_sublanes_per_tile = src.sublanesPerTile(ctx.target_shape);
+    const int64_t dst_sublanes_per_tile = dst.sublanesPerTile(ctx.target_shape);
+    xla::Array<Value> retiled(dst_vreg_array_shape);
+    SmallVector<int64_t> idxs;
+    retiled.Each([&](absl::Span<const int64_t> dst_idx, Value *vreg) {
+      const int64_t dst_col_idx = *(dst_idx.end() - 1);
+      const int64_t base_dst_tile_idx = dst_col_idx * dst_tiles_per_vreg;
+      const int64_t base_src_tile_idx =
+          src_offsets[1].has_value()
+              ? base_dst_tile_idx +
+                    (*src_offsets[1] - *dst_minor_offset) / src_tiling[1]
+              : 0;
+      // The following should be true from our choice of minor offset:
+      DCHECK_EQ(base_src_tile_idx % dst_tiles_per_vreg, 0);
+      const int64_t src_col_idx = base_src_tile_idx / src_tiles_per_vreg;
+      SmallVector<int32_t> gather_pattern;
+      // Iterate over the sublanes in the dst vreg:
+      for (int32_t sublane = 0; sublane < ctx.target_shape[0]; ++sublane) {
+        const int64_t dst_tile_idx_in_vreg = sublane / dst_sublanes_per_tile;
+        const int64_t src_tile_idx_in_vreg =
+            base_src_tile_idx % src_tiles_per_vreg + dst_tile_idx_in_vreg;
+        // Although replication may give us several sublanes to choose from,
+        // we always gather from the first sublane in the source tile. This
+        // degenerates to a broadcast when dst_tiling is native, which can
+        // be cheaper than an arbitrary gather (for some hardware gens).
+        const int64_t src_sublane_in_tile =
+            src_offsets[0].value_or(0) / packing;
+        const int64_t src_sublane =
+            src_tile_idx_in_vreg * src_sublanes_per_tile + src_sublane_in_tile;
+        gather_pattern.push_back(src_sublane);
       }
+      idxs.assign(dst_idx.begin(), dst_idx.end());
+      *(idxs.end() - 2) = 0;
+      *(idxs.end() - 1) = src_col_idx;
+      Value src_vreg = vregs(idxs);
+      *vreg = builder.create<tpu::GatherOp>(loc, src_vreg.getType(), src_vreg,
+                                            gather_pattern,
+                                            /*dimension=*/0);
     });
-    // We have successfully replicated sublanes
     return std::pair(dst, std::move(retiled));
   }
   VectorLayout dst(src.bitwidth(), src.offsets(), dst_tiling,
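To make the index arithmetic above concrete, here is a self-contained sketch (plain C++, constants chosen for illustration; it mirrors the arithmetic of the `retiled.Each` body rather than calling Mosaic) of the gather pattern for the 32-bit (1, 128) -> (8, 128) case with canonical offsets {replicated, 0}. Each dst vreg gathers one source sublane into all eight of its sublanes, i.e. the gather degenerates to the broadcast that the old code emitted via `broadcastSublane`.

    #include <cstdint>
    #include <iostream>

    int main() {
      // Illustrative constants for 32-bit data on an (8, 128) target,
      // retiling (1, 128) -> (8, 128) with canonical src offsets {replicated, 0}.
      const int64_t src_tiling_minor = 128;   // src_tiling[1]
      const int64_t src_tiles_per_vreg = 8;   // eight (1, 128) tiles per vreg
      const int64_t dst_tiles_per_vreg = 1;   // one native (8, 128) tile per vreg
      const int64_t src_sublanes_per_tile = 1;
      const int64_t dst_sublanes_per_tile = 8;
      const int64_t target_sublanes = 8;      // ctx.target_shape[0]
      const int64_t src_minor_offset = 0;     // *src_offsets[1]
      const int64_t dst_minor_offset = 0;     // src_minor_offset % dst_vreg_slice[1]
      const int64_t src_sublane_in_tile = 0;  // src_offsets[0].value_or(0) / packing

      // Same arithmetic as the retiled.Each body, for the first ten dst columns.
      for (int64_t dst_col_idx = 0; dst_col_idx < 10; ++dst_col_idx) {
        const int64_t base_dst_tile_idx = dst_col_idx * dst_tiles_per_vreg;
        const int64_t base_src_tile_idx =
            base_dst_tile_idx +
            (src_minor_offset - dst_minor_offset) / src_tiling_minor;
        const int64_t src_col_idx = base_src_tile_idx / src_tiles_per_vreg;
        std::cout << "dst col " << dst_col_idx << " <- src col " << src_col_idx
                  << ", gather pattern:";
        for (int64_t sublane = 0; sublane < target_sublanes; ++sublane) {
          const int64_t dst_tile_idx_in_vreg = sublane / dst_sublanes_per_tile;
          const int64_t src_tile_idx_in_vreg =
              base_src_tile_idx % src_tiles_per_vreg + dst_tile_idx_in_vreg;
          const int64_t src_sublane =
              src_tile_idx_in_vreg * src_sublanes_per_tile + src_sublane_in_tile;
          std::cout << " " << src_sublane;
        }
        std::cout << "\n";
      }
      return 0;
    }

Each printed pattern repeats a single sublane index eight times: dst column j reads sublane j % 8 of src column j / 8, matching the one-to-many mapping described in the comment above.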
@@ -6576,8 +6617,11 @@ FailureOr<TypedValue<VectorType>> relayout(RewriteContext &ctx,
     return assemble_with_mask_check(src_tiles,
                                     /*use_implicit_shape=*/true);
   }
-  if (src.layout_rank() >= dst.layout_rank() && !src.offsets()[0].has_value() &&
-      !src.offsets()[1].has_value()) {
+
+  if (const LayoutOffsets src_offsets =
+          src.getCanonicalOffsets(vty.getShape(), ctx.target_shape);
+      src.layout_rank() >= dst.layout_rank() && !src_offsets[0].has_value() &&
+      !src_offsets[1].has_value()) {
     // A fully replicated value is always easy to relayout
     xla::Array<Value> dst_tiles(
         dst.tileArrayImplicitShape(vty.getShape(), target_shape));
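The relayout change is easiest to see with offsets that are only canonically replicated. Below is a minimal sketch (plain C++, reusing the same hypothetical `canonicalOffsets` stand-in as in the sketch after the layout.h hunk): a single-row value with (1, 128) tiling and offsets {0, replicated} fails the old raw-offset check but passes the new canonical-offset check, so it now takes the cheap fully-replicated path.

    #include <array>
    #include <cstdint>
    #include <iostream>
    #include <optional>

    using LayoutOffset = std::optional<int64_t>;
    using LayoutOffsets = std::array<LayoutOffset, 2>;

    // Hypothetical stand-in for VectorLayout::getCanonicalOffsets (see the
    // sketch after the layout.h hunk).
    LayoutOffsets canonicalOffsets(const LayoutOffsets& offsets,
                                   const std::array<int64_t, 2>& tiling,
                                   const std::array<int64_t, 2>& tiled_ishape) {
      return {(tiling[0] == 1 && tiled_ishape[0] == 1) ? std::nullopt : offsets[0],
              offsets[1]};
    }

    int main() {
      // Single implicit row, (1, 128) tiling, offsets {0, replicated}:
      const LayoutOffsets raw = {0, std::nullopt};
      const LayoutOffsets canonical = canonicalOffsets(raw, {1, 128}, {1, 512});
      // The fully-replicated fast path in relayout requires both offsets to be
      // replicated (std::nullopt).
      std::cout << "raw offsets pass:       "
                << (!raw[0].has_value() && !raw[1].has_value()) << "\n";  // 0
      std::cout << "canonical offsets pass: "
                << (!canonical[0].has_value() && !canonical[1].has_value())
                << "\n";                                                  // 1
      return 0;
    }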