Skip to content

Commit

Permalink
[XLA:GPU] Add support for iota in the Triton fusion emitter.
Browse files Browse the repository at this point in the history
`Iota` must be treated like a parameter, i.e. it needs to be offset, and
potentially strided. We therefore need to ensure that `tile_offsets_indexing`
is always derived for the instruction.

PiperOrigin-RevId: 679023133
  • Loading branch information
bchetioui authored and Google-ML-Automation committed Sep 26, 2024
1 parent a769695 commit 843b668
Show file tree
Hide file tree
Showing 8 changed files with 163 additions and 3 deletions.
2 changes: 2 additions & 0 deletions xla/service/gpu/fusions/triton/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,8 @@ xla_test(
":triton_test_utils",
"//xla:autotuning_proto_cc",
"//xla:error_spec",
"//xla:shape_util",
"//xla:xla_data_proto_cc",
"//xla:xla_proto_cc",
"//xla/hlo/ir:hlo",
"//xla/service/gpu:backend_configs_cc",
Expand Down
54 changes: 54 additions & 0 deletions xla/service/gpu/fusions/triton/triton_fusion_emitter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -928,6 +928,56 @@ Value EmitTiledBroadcast(
padded_output_tile_shape);
}

// Emits the tile described by `tiled_iota` by computing its values instead of
// loading them.
//
// The values along the iota dimension are built in i32 as
//   tile_offset[iota_dim] + tile_stride[iota_dim] * Range(padded_tile_size)
// then cast to the Triton type corresponding to the HLO element type, and
// finally broadcast to the full padded tile shape.
//
// `tile_multi_index` is supplied as the dimension values when applying the
// instruction's tile offsets indexing map.
//
// Propagates any error returned by `tiled_iota.tile_offsets_indexing()` or by
// `TritonType` for the iota's element type.
absl::StatusOr<Value> EmitTiledIota(ImplicitLocOpBuilder& b,
                                    ValueRange tile_multi_index,
                                    const TiledHloInstruction& tiled_iota) {
  const HloIotaInstruction* hlo_iota =
      ::xla::Cast<HloIotaInstruction>(tiled_iota.hlo());
  const int64_t iota_dim = hlo_iota->iota_dimension();

  SmallVector<int64_t> padded_tile_sizes =
      GetPaddedTileSizes(tiled_iota.tile_sizes());
  const int64_t iota_tile_size = padded_tile_sizes[iota_dim];

  // We can treat iota more or less as a parameter load, except that we need to
  // generate the right values in the right place as opposed to loading them.
  TF_ASSIGN_OR_RETURN(IndexingMap tile_offsets_indexing,
                      tiled_iota.tile_offsets_indexing());

  // Only the offset along the iota dimension influences the produced values.
  auto iota_dim_offset = b.create<ma::IndexCastUIOp>(
      b.getI32Type(), mlir_converter::ApplyIndexing(
                          tile_offsets_indexing, /*dims=*/tile_multi_index,
                          /*symbols=*/{}, b)[iota_dim]);

  // First, stride as needed between the iota components.
  Value range = b.create<ma::MulIOp>(
      Range(b, iota_tile_size),
      Splat(b,
            CreateConst(b, b.getI32Type(), tiled_iota.tile_strides()[iota_dim]),
            iota_tile_size));

  // Then, add the base offset to the iota components.
  range = b.create<ma::AddIOp>(range,
                               Splat(b, iota_dim_offset, iota_tile_size));

  // Cast the result to the targeted type.
  TF_ASSIGN_OR_RETURN(Type iota_element_type,
                      TritonType(b, hlo_iota->shape().element_type()));

  range = Cast(b, range, iota_element_type);

  // And finally, produce a broadcast along the non-iota dimensions in order to
  // produce the whole iota tile. One unit dimension is inserted per non-iota
  // dimension: in front of the range for dimensions preceding `iota_dim`, and
  // at the end for the dimensions following it. The rank is kept signed so
  // that the loop bound neither mixes signedness nor wraps around.
  const int rank = static_cast<int>(padded_tile_sizes.size());
  for (int i = 0; i + 1 < rank; ++i) {
    const int axis = i < iota_dim ? 0 : i + 1;
    range = b.create<mt::ExpandDimsOp>(range, /*axis=*/axis);
  }

  return Broadcast(b, mlir::cast<TensorValue>(range), padded_tile_sizes);
}

Value EmitTiledReshape(ImplicitLocOpBuilder& b, ArrayRef<int64_t> tile_sizes,
Value input) {
SmallVector<int64_t> padded_tile_sizes = GetPaddedTileSizes(tile_sizes);
Expand Down Expand Up @@ -1057,6 +1107,10 @@ absl::StatusOr<Value> EmitTiledHloInstruction(
absl::StrCat("Unsupported non-scalar constant ", hlo->ToString()));
}

if (hlo->opcode() == HloOpcode::kIota) {
return EmitTiledIota(b, tile_multi_index, tiled_hlo);
}

if (hlo->opcode() == HloOpcode::kBroadcast) {
return EmitTiledBroadcast(b, tiled_hlo, values);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ limitations under the License.
#include <gtest/gtest.h>
#include "absl/status/status.h"
#include "absl/strings/string_view.h"
#include "absl/strings/substitute.h"
#include "llvm/IR/LLVMContext.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/Pass/PassManager.h"
Expand All @@ -30,6 +31,7 @@ limitations under the License.
#include "xla/hlo/ir/hlo_computation.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_instructions.h"
#include "xla/primitive_util.h"
#include "xla/service/gpu/backend_configs.pb.h"
#include "xla/service/gpu/fusions/triton/triton_fusion_emitter.h"
#include "xla/service/gpu/fusions/triton/triton_test_utils.h"
Expand All @@ -40,6 +42,7 @@ limitations under the License.
#include "xla/tests/verified_hlo_module.h"
#include "xla/tsl/lib/core/status_test_util.h"
#include "xla/xla.pb.h"
#include "xla/xla_data.pb.h"
#include "tsl/platform/status_matchers.h"
#include "tsl/platform/statusor.h"
#include "tsl/platform/test.h"
Expand Down Expand Up @@ -1196,6 +1199,72 @@ ENTRY main {
RunAndCompareNoHloPasses(kHloText, ErrorSpec{/*aabs=*/0, /*arel=*/0}));
}

// Verifies that an iota read through a strided slice is emitted as a
// `tt.make_range` that is then multiplied, and that the compiled fusion
// produces the same result as the reference.
TEST_F(TritonEmitterTest, StridedIota4DIsCodegeneratedCorrectly) {
  // The slice walks the iota dimension (dimension 2) from 91 to 1000 with a
  // stride of 5.
  constexpr std::string_view kStridedIotaHlo = R"(
triton_computation {
iota = f32[3,4,1000,5] iota(), iota_dimension=2
ROOT slice = f32[3,4,182,5] slice(iota), slice={[0:3], [0:4], [91:1000:5], [0:5]}
}
ENTRY main {
ROOT triton_fusion = f32[3,4,182,5] fusion(),
kind=kCustom, calls=triton_computation,
backend_config={"fusion_backend_config":
{"kind":"__triton",
"block_level_fusion_config":{"output_tile_sizes":["1","2","64","8"],
"num_warps":"1"}}}
})";

  // The range must have the size of the tile along the iota dimension and be
  // consumed by an integer multiplication.
  static constexpr char kExpectedIrPattern[] = R"(
CHECK: %[[RANGE:.*]] = tt.make_range {{.*}} : tensor<64xi32>
CHECK: arith.muli{{.*}} %[[RANGE]]
)";

  TF_EXPECT_OK(CreateTritonIrAndFileCheck(
      this, kStridedIotaHlo, "triton_computation", kExpectedIrPattern));

  // Zero absolute and relative tolerance.
  const ErrorSpec zero_tolerance{/*aabs=*/0, /*arel=*/0};
  EXPECT_TRUE(RunAndCompareNoHloPasses(kStridedIotaHlo, zero_tolerance));
}

// Fixture for Triton emitter tests that are parameterized over a
// `PrimitiveType`; it adds no state or behavior on top of `TritonEmitterTest`.
class IotaEmitterParametrizedTest
: public TritonEmitterTest,
public ::testing::WithParamInterface<PrimitiveType> {};

// Verifies, for the element type given by the test parameter, that a root iota
// is emitted as a `tt.make_range` followed by an integer addition and a
// broadcast to the tile shape, and that the compiled fusion produces the same
// result as the reference.
TEST_P(IotaEmitterParametrizedTest, Iota4DIsCodegeneratedCorrectly) {
  const PrimitiveType element_type = GetParam();
  const auto& element_type_name =
      primitive_util::LowercasePrimitiveTypeName(element_type);

  // `$0` is replaced with the HLO spelling of the element type.
  const std::string hlo_text = absl::Substitute(R"(
triton_computation {
ROOT iota = $0[3,4,1000,5] iota(), iota_dimension=2
}
ENTRY main {
ROOT triton_fusion = $0[3,4,1000,5] fusion(),
kind=kCustom, calls=triton_computation,
backend_config={"fusion_backend_config":
{"kind":"__triton",
"block_level_fusion_config":{"output_tile_sizes":["1","2","64","8"],
"num_warps":"1"}}}
})",
                                                element_type_name);

  static constexpr char kExpectedIrPattern[] = R"(
CHECK: %[[RANGE:.*]] = tt.make_range {{.*}} : tensor<64xi32>
CHECK: arith.addi{{.*}} %[[RANGE]]
// Omit the data type below, since it depends on a test parameter
// and is not abbreviated the same as in HLO.
CHECK: tt.broadcast {{.*}} -> tensor<1x2x64x8x
)";

  TF_EXPECT_OK(CreateTritonIrAndFileCheck(this, hlo_text, "triton_computation",
                                          kExpectedIrPattern));

  // Zero absolute and relative tolerance.
  const ErrorSpec zero_tolerance{/*aabs=*/0, /*arel=*/0};
  EXPECT_TRUE(RunAndCompareNoHloPasses(hlo_text, zero_tolerance));
}

// Runs the parameterized iota tests once per listed element type.
INSTANTIATE_TEST_SUITE_P(IotaEmitterParametrizedTestSuite,
                         IotaEmitterParametrizedTest,
                         ::testing::Values(S8, S16, S32, S64, BF16, F16, F32,
                                           F64));

} // namespace
} // namespace gpu
} // namespace xla
9 changes: 9 additions & 0 deletions xla/service/gpu/fusions/triton/triton_support.cc
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,15 @@ CodegenDecision IsTritonSupportedInstructionImpl(
"Only scalar constants are supported in Triton.");
}

if (instr.opcode() == HloOpcode::kIota) {
PrimitiveType element_type = instr.shape().element_type();
return element_type != PrimitiveType::F8E4M3FN &&
element_type != PrimitiveType::F8E5M2
? CodegenDecision::Allow()
: CodegenDecision::Forbid(
"F8E4M3FN and F8E5M2 are not supported for iota.");
}

if (instr.IsElementwise()) {
if (!IsTritonSupportedElementwise(
instr.opcode(),
Expand Down
1 change: 1 addition & 0 deletions xla/service/gpu/model/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -780,6 +780,7 @@ xla_cc_test(
"@com_google_absl//absl/types:span",
"@com_google_googletest//:gtest_main",
"@llvm-project//mlir:IR",
"@tsl//tsl/platform:status_matchers",
"@tsl//tsl/platform:statusor",
],
)
Expand Down
5 changes: 3 additions & 2 deletions xla/service/gpu/model/symbolic_tile_analysis.cc
Original file line number Diff line number Diff line change
Expand Up @@ -345,7 +345,7 @@ void SortTiledHloInstructionsInPostOrder(
});
}

} // namespace
} // anonymous namespace

/*static*/ SymbolicTileAnalysisOrError SymbolicTileAnalysis::AnalyzeComputation(
const HloComputation& computation, MLIRContext* ctx,
Expand Down Expand Up @@ -562,7 +562,8 @@ SymbolicTileAnalysis::ComputeTiledHloInstructions(

std::optional<IndexingMap> tile_offset_indexing;
if (compute_all_tile_offset_indexing_maps ||
parameters_with_offset_indexing.contains(symbolic_tiled_hlo->hlo())) {
parameters_with_offset_indexing.contains(symbolic_tiled_hlo->hlo()) ||
symbolic_tiled_hlo->hlo()->opcode() == HloOpcode::kIota) {
TF_ASSIGN_OR_RETURN(
tile_offset_indexing,
ComputeTileOffsetIndexing(
Expand Down
2 changes: 1 addition & 1 deletion xla/service/gpu/model/symbolic_tile_analysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ class SymbolicTileAnalysis {

// Returns a graph of HLO instructions tiled with the given tile parameters.
// The provided tile parameters must satisfy the analysis's constraints.
// By default, `ComputetiledHloInstructions` performs a check that the
// By default, `ComputeTiledHloInstructions` performs a check that the
// constraints are satisfied by the chosen tiled parameters. Setting
// `constraints_are_known_satisfied` to true bypasses this check.
//
Expand Down
24 changes: 24 additions & 0 deletions xla/service/gpu/model/symbolic_tile_analysis_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ limitations under the License.
#include "xla/tests/verified_hlo_module.h"
#include "xla/tsl/lib/core/status_test_util.h"
#include "xla/util.h"
#include "tsl/platform/status_matchers.h"
#include "tsl/platform/statusor.h"

namespace xla {
Expand Down Expand Up @@ -1020,6 +1021,29 @@ ENTRY main {
EXPECT_TRUE(analysis.has_value());
}

// Verifies that the tiled iota exposes a tile offsets indexing map even when
// the caller does not request offsets indexing for all instructions.
TEST_F(SymbolicTileAnalysisTest, IotaAlwaysHasTileOffsetsIndexingSet) {
  // A fusion whose root (and only instruction) is the iota.
  static constexpr char kIotaFusionHlo[] = R"(
fusion {
ROOT iota = s32[100] iota(), iota_dimension=0
}
ENTRY main {
ROOT fusion = s32[100] fusion(), kind=kLoop, calls=fusion
})";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
                          ParseAndReturnVerifiedModule(kIotaFusionHlo));

  std::optional<SymbolicTileAnalysis> tile_analysis =
      TryAnalyzeModule(module.get());
  ASSERT_TRUE(tile_analysis.has_value());

  // `compute_all_tile_offset_indexing_maps` is deliberately false here.
  TF_ASSERT_OK_AND_ASSIGN(
      TiledHloComputation tiled_computation,
      tile_analysis->ComputeTiledHloInstructions(
          /*tile_parameters=*/{4},
          /*constraints_are_known_satisfied=*/false,
          /*compute_all_tile_offset_indexing_maps=*/false));

  const TiledHloInstruction* tiled_iota = tiled_computation.GetRoot();
  EXPECT_THAT(tiled_iota->tile_offsets_indexing().status(),
              ::tsl::testing::IsOk());
}

} // namespace
} // namespace gpu
} // namespace xla

0 comments on commit 843b668

Please sign in to comment.