diff --git a/xla/service/gpu/fusions/triton/BUILD b/xla/service/gpu/fusions/triton/BUILD index 7018af701cc22..f3669c4de9522 100644 --- a/xla/service/gpu/fusions/triton/BUILD +++ b/xla/service/gpu/fusions/triton/BUILD @@ -261,6 +261,8 @@ xla_test( ":triton_test_utils", "//xla:autotuning_proto_cc", "//xla:error_spec", + "//xla:shape_util", + "//xla:xla_data_proto_cc", "//xla:xla_proto_cc", "//xla/hlo/ir:hlo", "//xla/service/gpu:backend_configs_cc", diff --git a/xla/service/gpu/fusions/triton/triton_fusion_emitter.cc b/xla/service/gpu/fusions/triton/triton_fusion_emitter.cc index 4a4245c159067..a105789fa86f0 100644 --- a/xla/service/gpu/fusions/triton/triton_fusion_emitter.cc +++ b/xla/service/gpu/fusions/triton/triton_fusion_emitter.cc @@ -928,6 +928,56 @@ Value EmitTiledBroadcast( padded_output_tile_shape); } +absl::StatusOr EmitTiledIota(ImplicitLocOpBuilder& b, + ValueRange tile_multi_index, + const TiledHloInstruction& tiled_iota) { + const HloIotaInstruction* hlo_iota = + ::xla::Cast(tiled_iota.hlo()); + int64_t iota_dim = hlo_iota->iota_dimension(); + + SmallVector padded_tile_sizes = + GetPaddedTileSizes(tiled_iota.tile_sizes()); + + // We can treat iota more or less as a parameter load, except that we need to + // generate the right values in the right place as opposed to loading them. + TF_ASSIGN_OR_RETURN(IndexingMap tile_offsets_indexing, + tiled_iota.tile_offsets_indexing()); + + auto iota_dim_offset = b.create( + b.getI32Type(), mlir_converter::ApplyIndexing( + tile_offsets_indexing, /*dims=*/tile_multi_index, + /*symbols=*/{}, b)[iota_dim]); + + // First, stride as needed between the iota components. + Value range = b.create( + Range(b, padded_tile_sizes[iota_dim]), + Splat(b, + CreateConst(b, b.getI32Type(), tiled_iota.tile_strides()[iota_dim]), + padded_tile_sizes[iota_dim])); + + // Then, add the base offset to the iota components. + range = b.create( + range, Splat(b, iota_dim_offset, padded_tile_sizes[iota_dim])); + + // Cast the result to the targeted type. + TF_ASSIGN_OR_RETURN(Type iota_element_type, + TritonType(b, hlo_iota->shape().element_type())); + + range = Cast(b, range, iota_element_type); + + // And finally, produce a broadcast along the non-iota dimensions in order to + // produce the whole iota tile. + for (int i = 0; i < padded_tile_sizes.size() - 1; i++) { + if (i < iota_dim) { + range = b.create(range, /*axis=*/0); + } else { + range = b.create(range, /*axis=*/i + 1); + } + } + + return Broadcast(b, mlir::cast(range), padded_tile_sizes); +} + Value EmitTiledReshape(ImplicitLocOpBuilder& b, ArrayRef tile_sizes, Value input) { SmallVector padded_tile_sizes = GetPaddedTileSizes(tile_sizes); @@ -1057,6 +1107,10 @@ absl::StatusOr EmitTiledHloInstruction( absl::StrCat("Unsupported non-scalar constant ", hlo->ToString())); } + if (hlo->opcode() == HloOpcode::kIota) { + return EmitTiledIota(b, tile_multi_index, tiled_hlo); + } + if (hlo->opcode() == HloOpcode::kBroadcast) { return EmitTiledBroadcast(b, tiled_hlo, values); } diff --git a/xla/service/gpu/fusions/triton/triton_fusion_emitter_device_test.cc b/xla/service/gpu/fusions/triton/triton_fusion_emitter_device_test.cc index f136f7190d1a6..d79a54f461568 100644 --- a/xla/service/gpu/fusions/triton/triton_fusion_emitter_device_test.cc +++ b/xla/service/gpu/fusions/triton/triton_fusion_emitter_device_test.cc @@ -21,6 +21,7 @@ limitations under the License. #include #include "absl/status/status.h" #include "absl/strings/string_view.h" +#include "absl/strings/substitute.h" #include "llvm/IR/LLVMContext.h" #include "mlir/IR/MLIRContext.h" #include "mlir/Pass/PassManager.h" @@ -30,6 +31,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" +#include "xla/primitive_util.h" #include "xla/service/gpu/backend_configs.pb.h" #include "xla/service/gpu/fusions/triton/triton_fusion_emitter.h" #include "xla/service/gpu/fusions/triton/triton_test_utils.h" @@ -40,6 +42,7 @@ limitations under the License. #include "xla/tests/verified_hlo_module.h" #include "xla/tsl/lib/core/status_test_util.h" #include "xla/xla.pb.h" +#include "xla/xla_data.pb.h" #include "tsl/platform/status_matchers.h" #include "tsl/platform/statusor.h" #include "tsl/platform/test.h" @@ -1196,6 +1199,72 @@ ENTRY main { RunAndCompareNoHloPasses(kHloText, ErrorSpec{/*aabs=*/0, /*arel=*/0})); } +TEST_F(TritonEmitterTest, StridedIota4DIsCodegeneratedCorrectly) { + constexpr std::string_view kHloText = R"( +triton_computation { + iota = f32[3,4,1000,5] iota(), iota_dimension=2 + ROOT slice = f32[3,4,182,5] slice(iota), slice={[0:3], [0:4], [91:1000:5], [0:5]} +} + +ENTRY main { + ROOT triton_fusion = f32[3,4,182,5] fusion(), + kind=kCustom, calls=triton_computation, + backend_config={"fusion_backend_config": + {"kind":"__triton", + "block_level_fusion_config":{"output_tile_sizes":["1","2","64","8"], + "num_warps":"1"}}} +})"; + + TF_EXPECT_OK( + CreateTritonIrAndFileCheck(this, kHloText, "triton_computation", R"( +CHECK: %[[RANGE:.*]] = tt.make_range {{.*}} : tensor<64xi32> +CHECK: arith.muli{{.*}} %[[RANGE]] +)")); + + EXPECT_TRUE( + RunAndCompareNoHloPasses(kHloText, ErrorSpec{/*aabs=*/0, /*arel=*/0})); +} + +class IotaEmitterParametrizedTest + : public TritonEmitterTest, + public ::testing::WithParamInterface {}; + +TEST_P(IotaEmitterParametrizedTest, Iota4DIsCodegeneratedCorrectly) { + auto data_type = GetParam(); + const std::string kHloText = + absl::Substitute(R"( +triton_computation { + ROOT iota = $0[3,4,1000,5] iota(), iota_dimension=2 +} + +ENTRY main { + ROOT triton_fusion = $0[3,4,1000,5] fusion(), + kind=kCustom, calls=triton_computation, + backend_config={"fusion_backend_config": + {"kind":"__triton", + "block_level_fusion_config":{"output_tile_sizes":["1","2","64","8"], + "num_warps":"1"}}} +})", + primitive_util::LowercasePrimitiveTypeName(data_type)); + + TF_EXPECT_OK( + CreateTritonIrAndFileCheck(this, kHloText, "triton_computation", R"( +CHECK: %[[RANGE:.*]] = tt.make_range {{.*}} : tensor<64xi32> +CHECK: arith.addi{{.*}} %[[RANGE]] + // Omit the data type below, since it depends on a test parameter + // and is not abbreviated the same as in HLO. +CHECK: tt.broadcast {{.*}} -> tensor<1x2x64x8x +)")); + + EXPECT_TRUE( + RunAndCompareNoHloPasses(kHloText, ErrorSpec{/*aabs=*/0, /*arel=*/0})); +} + +INSTANTIATE_TEST_SUITE_P(IotaEmitterParametrizedTestSuite, + IotaEmitterParametrizedTest, + ::testing::ValuesIn({S8, S16, S32, S64, BF16, F16, F32, + F64})); + } // namespace } // namespace gpu } // namespace xla diff --git a/xla/service/gpu/fusions/triton/triton_support.cc b/xla/service/gpu/fusions/triton/triton_support.cc index d0a33343fa223..d006ae65fcc55 100644 --- a/xla/service/gpu/fusions/triton/triton_support.cc +++ b/xla/service/gpu/fusions/triton/triton_support.cc @@ -285,6 +285,15 @@ CodegenDecision IsTritonSupportedInstructionImpl( "Only scalar constants are supported in Triton."); } + if (instr.opcode() == HloOpcode::kIota) { + PrimitiveType element_type = instr.shape().element_type(); + return element_type != PrimitiveType::F8E4M3FN && + element_type != PrimitiveType::F8E5M2 + ? CodegenDecision::Allow() + : CodegenDecision::Forbid( + "F8E4M3FN and F8E5M2 are not supported for iota."); + } + if (instr.IsElementwise()) { if (!IsTritonSupportedElementwise( instr.opcode(), diff --git a/xla/service/gpu/model/BUILD b/xla/service/gpu/model/BUILD index ff7807077d66c..01ca60c0da3e2 100644 --- a/xla/service/gpu/model/BUILD +++ b/xla/service/gpu/model/BUILD @@ -780,6 +780,7 @@ xla_cc_test( "@com_google_absl//absl/types:span", "@com_google_googletest//:gtest_main", "@llvm-project//mlir:IR", + "@tsl//tsl/platform:status_matchers", "@tsl//tsl/platform:statusor", ], ) diff --git a/xla/service/gpu/model/symbolic_tile_analysis.cc b/xla/service/gpu/model/symbolic_tile_analysis.cc index 9b9e637904210..a08b8aaa3ea90 100644 --- a/xla/service/gpu/model/symbolic_tile_analysis.cc +++ b/xla/service/gpu/model/symbolic_tile_analysis.cc @@ -345,7 +345,7 @@ void SortTiledHloInstructionsInPostOrder( }); } -} // namespace +} // anonymous namespace /*static*/ SymbolicTileAnalysisOrError SymbolicTileAnalysis::AnalyzeComputation( const HloComputation& computation, MLIRContext* ctx, @@ -562,7 +562,8 @@ SymbolicTileAnalysis::ComputeTiledHloInstructions( std::optional tile_offset_indexing; if (compute_all_tile_offset_indexing_maps || - parameters_with_offset_indexing.contains(symbolic_tiled_hlo->hlo())) { + parameters_with_offset_indexing.contains(symbolic_tiled_hlo->hlo()) || + symbolic_tiled_hlo->hlo()->opcode() == HloOpcode::kIota) { TF_ASSIGN_OR_RETURN( tile_offset_indexing, ComputeTileOffsetIndexing( diff --git a/xla/service/gpu/model/symbolic_tile_analysis.h b/xla/service/gpu/model/symbolic_tile_analysis.h index 58a08afde9ba1..775de1670f51e 100644 --- a/xla/service/gpu/model/symbolic_tile_analysis.h +++ b/xla/service/gpu/model/symbolic_tile_analysis.h @@ -91,7 +91,7 @@ class SymbolicTileAnalysis { // Returns a graph of HLO instructions tiled with the given tile parameters. // The provided tile parameters must satisfy the analysis's constraints. - // By default, `ComputetiledHloInstructions` performs a check that the + // By default, `ComputeTiledHloInstructions` performs a check that the // constraints are satisfied by the chosen tiled parameters. Setting // `constraints_are_known_satisfied` to true bypasses this check. // diff --git a/xla/service/gpu/model/symbolic_tile_analysis_test.cc b/xla/service/gpu/model/symbolic_tile_analysis_test.cc index 99f136d3aecae..f4166076234a7 100644 --- a/xla/service/gpu/model/symbolic_tile_analysis_test.cc +++ b/xla/service/gpu/model/symbolic_tile_analysis_test.cc @@ -43,6 +43,7 @@ limitations under the License. #include "xla/tests/verified_hlo_module.h" #include "xla/tsl/lib/core/status_test_util.h" #include "xla/util.h" +#include "tsl/platform/status_matchers.h" #include "tsl/platform/statusor.h" namespace xla { @@ -1020,6 +1021,29 @@ ENTRY main { EXPECT_TRUE(analysis.has_value()); } +TEST_F(SymbolicTileAnalysisTest, IotaAlwaysHasTileOffsetsIndexingSet) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(R"( +fusion { + ROOT iota = s32[100] iota(), iota_dimension=0 +} + +ENTRY main { + ROOT fusion = s32[100] fusion(), kind=kLoop, calls=fusion +})")); + std::optional analysis = TryAnalyzeModule(module.get()); + ASSERT_TRUE(analysis.has_value()); + + TF_ASSERT_OK_AND_ASSIGN(TiledHloComputation tiled_hlo_computation, + analysis->ComputeTiledHloInstructions( + /*tile_parameters=*/{4}, + /*constraints_are_known_satisfied=*/false, + /*compute_all_tile_offset_indexing_maps=*/false)); + + const TiledHloInstruction* iota = tiled_hlo_computation.GetRoot(); + EXPECT_THAT(iota->tile_offsets_indexing().status(), ::tsl::testing::IsOk()); +} + } // namespace } // namespace gpu } // namespace xla