Skip to content

Commit

Permalink
[XLA:GPU][IndexAnalysis] Remove is_simplified flag.
Browse files Browse the repository at this point in the history
The benchmarks don't show a lot of improvements in compile time.

PiperOrigin-RevId: 679693188
  • Loading branch information
pifon2a authored and Google-ML-Automation committed Sep 27, 2024
1 parent 2cc2ff2 commit 0983168
Show file tree
Hide file tree
Showing 33 changed files with 570 additions and 986 deletions.
30 changes: 8 additions & 22 deletions xla/service/gpu/fusions/ir/tests/attrs.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,7 @@
// CHECK-SAME: d2 in [10, 12],
// CHECK-SAME: s0 in [0, 32],
// CHECK-SAME: d0 + s0 in [1, 10],
// CHECK-SAME: d0 mod 2 in [0, 1],
// CHECK-SAME: is_simplified: true"
// CHECK-SAME: d0 mod 2 in [0, 1]
// CHECK-SAME: >
#map = #xla_gpu.indexing_map<"(d0, d1, d2)[s0] -> (d0),"
"domain:"
Expand All @@ -18,8 +17,7 @@
"d2 in [10, 12],"
"s0 in [0, 32],"
"d0 mod 2 in [0, 1],"
"d0 + s0 in [1, 10],"
"is_simplified: true"
"d0 + s0 in [1, 10]"
>

func.func private @indexing_map_attr(!xla_gpu.indexed_vector<64x64x32xf64, #map>)
Expand All @@ -39,7 +37,6 @@ func.func private @indexing_map_attr(!xla_gpu.indexed_vector<64x64x32xf64, #map>
// CHECK-SAME: d0 + s0 in [1, 10]
// CHECK-SAME: d0 mod 2 in [0, 1]
// CHECK-SAME: d1 + s1 + s2 in [1, 32]
// CHECK-SAME: is_simplified: false"
// CHECK-SAME: >
#map = #xla_gpu.indexing_map<
"(d0, d1)[s0, s1, s2] -> (d0 + s0, d1 + s1, d1 + s2),"
Expand All @@ -51,8 +48,7 @@ func.func private @indexing_map_attr(!xla_gpu.indexed_vector<64x64x32xf64, #map>
"s2 in [0, 32],"
"d0 mod 2 in [0, 1],"
"d0 + s0 in [1, 10],"
"d1 + s1 + s2 in [1, 32],"
"is_simplified: false"
"d1 + s1 + s2 in [1, 32]"
>
func.func private @more_range_vars(!xla_gpu.indexed_vector<100x32xf64, #map>)
// CHECK-LABEL: @more_range_vars
Expand All @@ -65,13 +61,11 @@ func.func private @more_range_vars(!xla_gpu.indexed_vector<100x32xf64, #map>)
// CHECK-SAME: domain:
// CHECK-SAME: d0 in [0, 100]
// CHECK-SAME: s0 in [-3, -1]
// CHECK-SAME: is_simplified: false"
// CHECK-SAME: >
#map = #xla_gpu.indexing_map<"(d0)[s0] -> (d0),"
"domain:"
"d0 in [0, 100],"
"s0 in [-3, -1],"
"is_simplified: false"
"s0 in [-3, -1]"
>
func.func private @indexing_map_small(!xla_gpu.indexed_vector<100xf64, #map>)
// CHECK-LABEL: @indexing_map_small
Expand All @@ -86,15 +80,13 @@ func.func private @indexing_map_small(!xla_gpu.indexed_vector<100xf64, #map>)
// CHECK-SAME: d1 in [5, 8]
// CHECK-SAME: d2 in [10, 12]
// CHECK-SAME: s0 in [0, 32]
// CHECK-SAME: is_simplified: false"
// CHECK-SAME: >
#map = #xla_gpu.indexing_map<"(d0, d1, d2)[s0] -> (d0),"
"domain:"
"d0 in [1, 2],"
"d1 in [5, 8],"
"d2 in [10, 12],"
"s0 in [0, 32],"
"is_simplified: false"
"s0 in [0, 32]"
>
func.func private @no_constraints(!xla_gpu.indexed_vector<32xf64, #map>)
// CHECK-LABEL: @no_constraints
Expand All @@ -107,13 +99,11 @@ func.func private @no_constraints(!xla_gpu.indexed_vector<32xf64, #map>)
// CHECK-SAME: domain:
// CHECK-SAME: s0 in [3, 5]
// CHECK-SAME: s0 mod 2 in [0, 1]
// CHECK-SAME: is_simplified: false"
// CHECK-SAME: >
#map = #xla_gpu.indexing_map<"()[s0] -> (s0),"
"domain:"
"s0 in [3, 5],"
"s0 mod 2 in [0, 1],"
"is_simplified: false"
"s0 mod 2 in [0, 1]"
>
func.func private @no_dimensions(!xla_gpu.indexed_vector<100xf64, #map>)
// CHECK-LABEL: @no_dimensions
Expand All @@ -126,13 +116,11 @@ func.func private @no_dimensions(!xla_gpu.indexed_vector<100xf64, #map>)
// CHECK-SAME: domain:
// CHECK-SAME: d0 in [3, 5]
// CHECK-SAME: d0 mod 2 in [0, 1]
// CHECK-SAME: is_simplified: false"
// CHECK-SAME: >
#map = #xla_gpu.indexing_map<"(d0) -> (d0),"
"domain:"
"d0 in [3, 5],"
"d0 mod 2 in [0, 1],"
"is_simplified: false"
>
func.func private @no_symbols(!xla_gpu.indexed_vector<100xf64, #map>)
// CHECK-LABEL: @no_symbols
Expand All @@ -152,8 +140,6 @@ func.func private @empty(!xla_gpu.indexed_vector<100xf64, #map>)

func.func private @tensor_layout(
%in0: tensor<42xf32, #xla_gpu.layout<"shmem",
"(d0) -> (),"
"domain: d0 in [0, 42], is_simplified: true">>)
// CHECK: #layout = #xla_gpu.layout<"shmem", "(d0) -> (),
// CHECK-SAME: domain: d0 in [0, 42], is_simplified: true">
"(d0) -> ()," "domain: d0 in [0, 42]">>)
// CHECK: #layout = #xla_gpu.layout<"shmem", "(d0) -> (), domain:
// CHECK: tensor<42xf32, #layout>
60 changes: 23 additions & 37 deletions xla/service/gpu/fusions/ir/tests/canonicalize.mlir
Original file line number Diff line number Diff line change
@@ -1,21 +1,20 @@
// RUN: mlir_fusions_opt %s --split-input-file -canonicalize | FileCheck %s

#map0 = #xla_gpu.indexing_map<"()[s0, s1] -> (1 + s0 + s1 mod 3 - s1, s0 mod 2), domain: s0 in [-10, 10], s1 in [0, 2], is_simplified: false">
#map0 = #xla_gpu.indexing_map<"()[s0, s1] -> (1 + s0 + s1 mod 3 - s1, s0 mod 2), domain: s0 in [-10, 10], s1 in [0, 2]">
func.func @simplify_apply_indexing(%s0: index, %s1: index) -> (index, index) {
%0:2 = xla_gpu.apply_indexing #map0 [%s0, %s1]
func.return %0#0, %0#1 : index, index
}
// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<"(d0) -> (d0 + 1, d0 mod 2),
// CHECK-SAME: domain: d0 in [-10, 10]
// CHECK-SAME: is_simplified: true">
// CHECK-SAME: domain: d0 in [-10, 10]">

// CHECK-LABEL: func.func @simplify_apply_indexing
// CHECK-SAME: %[[ARG_0:.*]]: index, %[[ARG_1:.*]]: index)
// CHECK: xla_gpu.apply_indexing #[[$MAP]](%[[ARG_0]])

// -----

#map0 = #xla_gpu.indexing_map<"(d0, d1, d2)[s0, s1] -> (1 + s0 + s1 mod 4 - s1, s0 mod 2, d0 + d2), domain: d0 in [0, 1], d1 in [0, 2], d2 in [0, 3], s0 in [-11, 11], s1 in [0, 3], is_simplified: false">
#map0 = #xla_gpu.indexing_map<"(d0, d1, d2)[s0, s1] -> (1 + s0 + s1 mod 4 - s1, s0 mod 2, d0 + d2), domain: d0 in [0, 1], d1 in [0, 2], d2 in [0, 3], s0 in [-11, 11], s1 in [0, 3]">
func.func @simplify_apply_indexing_remove_dims(%d0: index, %d1: index,
%d2: index, %s0: index, %s1: index) -> (index, index, index) {
%0:3 = xla_gpu.apply_indexing #map0(%d0, %d1, %d2)[%s0, %s1]
Expand All @@ -35,16 +34,7 @@ func.func @simplify_apply_indexing_remove_dims(%d0: index, %d1: index,

// -----

#map0 = #xla_gpu.indexing_map<"(d0) -> (d0 mod 10), domain: d0 in [0, 9], is_simplified: true">
func.func @do_not_simplify_if_is_simplified_is_true(%d0: index) -> (index) {
%0 = xla_gpu.apply_indexing #map0(%d0)
func.return %0 : index
}
// CHECK: #xla_gpu.indexing_map<"(d0) -> (d0 mod 10)

// -----

#map0 = #xla_gpu.indexing_map<"(d0, d1)[s0] -> (d0 + s0, 4, d1, 1, s0), domain: d0 in [-10, 10], d1 in [0, 2], s0 in [-1, 1], is_simplified: false">
#map0 = #xla_gpu.indexing_map<"(d0, d1)[s0] -> (d0 + s0, 4, d1, 1, s0), domain: d0 in [-10, 10], d1 in [0, 2], s0 in [-1, 1]">
func.func @fold_indexing_map_results(%d0: index, %d1: index, %s0: index)
-> (index, index, index, index, index) {
%0:5 = xla_gpu.apply_indexing #map0 (%d0, %d1)[%s0]
Expand All @@ -64,7 +54,7 @@ func.func @fold_indexing_map_results(%d0: index, %d1: index, %s0: index)
// -----

#map0 = #xla_gpu.indexing_map<"(d0, d1)[s0] -> (d0 + s0, s0 + 4, d1 mod 2, 1 + d1, s0),"
"domain: d0 in [-10, 10], d1 in [0, 2], s0 in [-1, 1], is_simplified: false">
"domain: d0 in [-10, 10], d1 in [0, 2], s0 in [-1, 1]">
func.func @remove_unused_results(%d0: index, %d1: index, %s0: index) -> (index) {
%0:5 = xla_gpu.apply_indexing #map0 (%d0, %d1)[%s0]
func.return %0#2 : index
Expand All @@ -81,8 +71,7 @@ func.func @remove_unused_results(%d0: index, %d1: index, %s0: index) -> (index)
// -----

#map0 = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] -> (d0 + d1 + s0 + s1 mod 3),"
"domain: d0 in [0, 10], d1 in [0, 5], s0 in [-10, 10], s1 in [0, 4],"
"is_simplified: false">
"domain: d0 in [0, 10], d1 in [0, 5], s0 in [-10, 10], s1 in [0, 4]">
func.func @fold_operands(%d0: index) -> index {
%d1 = arith.constant 1 : index
%s0 = arith.constant 2 : index
Expand All @@ -102,7 +91,7 @@ func.func @fold_operands(%d0: index) -> index {
func.func @fold_operands_and_results(%arg0: index, %arg1: index)
-> (index, index) {
%0:2 = xla_gpu.apply_indexing #xla_gpu.indexing_map<"(d0, d1) -> (0, d1),"
"domain: d0 in [0, 4], d1 in [0, 5], is_simplified: false">(%arg0, %arg1)
"domain: d0 in [0, 4], d1 in [0, 5]">(%arg0, %arg1)
return %0#0, %0#1 : index, index
}

Expand All @@ -115,10 +104,9 @@ func.func @fold_operands_and_results(%arg0: index, %arg1: index)

func.func @fold_sequence(%arg0: index, %arg1: index) -> index {
%0 = xla_gpu.apply_indexing #xla_gpu.indexing_map<
"(d0, d1) -> (d0 + d1), domain: d0 in [0, 5], d1 in [0, 4],"
"is_simplified: false">(%arg0, %arg1)
"(d0, d1) -> (d0 + d1), domain: d0 in [0, 5], d1 in [0, 4]">(%arg0, %arg1)
%1 = xla_gpu.apply_indexing #xla_gpu.indexing_map<"(d0) -> (d0 mod 100 + 42),"
"domain: d0 in [0, 10000], is_simplified: false">(%0)
"domain: d0 in [0, 10000]">(%0)
func.return %1 : index
}

Expand All @@ -133,10 +121,9 @@ func.func @fold_sequence(%arg0: index, %arg1: index) -> index {

func.func @fold_sequence_sym(%arg0: index, %arg1: index) -> index {
%0 = xla_gpu.apply_indexing #xla_gpu.indexing_map<"(d0, d1) -> (d0 + d1), "
"domain: d0 in [0, 5], d1 in [0, 4], is_simplified: false">(%arg0, %arg1)
"domain: d0 in [0, 5], d1 in [0, 4]">(%arg0, %arg1)
%1 = xla_gpu.apply_indexing #xla_gpu.indexing_map<
"()[s0] -> (s0 mod 100 + 42), domain: s0 in [0, 10000],"
"is_simplified: false">(%0)
"()[s0] -> (s0 mod 100 + 42), domain: s0 in [0, 10000]">(%0)
func.return %1 : index
}

Expand All @@ -150,10 +137,10 @@ func.func @fold_sequence_sym(%arg0: index, %arg1: index) -> index {
// -----

#indexing_map1 = #xla_gpu.indexing_map<"(d0, d1) -> (d1 * 2 + d0 + 8512),"
"domain: d0 in [0, 1], d1 in [0, 607], is_simplified: false">
"domain: d0 in [0, 1], d1 in [0, 607]">
#indexing_map2 = #xla_gpu.indexing_map<"(d0, d1, d2) -> ("
"((d1 floordiv 32 + 1) mod 3) * 64 + (d1 mod 32) * 2 + (d0 floordiv 192) * 192 + d2),"
"domain: d0 in [0, 9407], d1 in [0, 607], d2 in [0, 1], is_simplified: false">
"domain: d0 in [0, 9407], d1 in [0, 607], d2 in [0, 1]">

func.func @fold_sequence_no_simplification_needed(%i: index) -> index {
%thread_id_x = gpu.thread_id x {xla.range = [0 : index, 607 : index]}
Expand All @@ -167,11 +154,11 @@ func.func @fold_sequence_no_simplification_needed(%i: index) -> index {
// -----

#indexing_map1 = #xla_gpu.indexing_map<
"(d0) -> (3 * d0), domain: d0 in [0, 9407], is_simplified: false">
"(d0) -> (3 * d0), domain: d0 in [0, 9407]">
#indexing_map2 = #xla_gpu.indexing_map<"(d0, d1, d2) -> (d0 floordiv 32 + 1),"
"domain: d0 in [0, 9407], d1 in [0, 607], d2 in [0, 1], is_simplified: false">
"domain: d0 in [0, 9407], d1 in [0, 607], d2 in [0, 1]">
#indexing_map3 = #xla_gpu.indexing_map<"(d0, d1, d2) -> (d0 floordiv 32 + 2),"
"domain: d0 in [0, 9407], d1 in [0, 607], d2 in [0, 1], is_simplified: false">
"domain: d0 in [0, 9407], d1 in [0, 607], d2 in [0, 1]">

func.func @no_fold_when_producer_has_two_users(%i: index) -> (index, index) {
%thread_id_x = gpu.thread_id x {xla.range = [0 : index, 607 : index]}
Expand All @@ -186,9 +173,9 @@ func.func @no_fold_when_producer_has_two_users(%i: index) -> (index, index) {

func.func @fold_sequence_shared_operands(%arg0: index, %arg1: index) -> index {
%0 = xla_gpu.apply_indexing #xla_gpu.indexing_map<"(d0, d1) -> (d0 + d1),"
"domain: d0 in [0, 5], d1 in [0, 4], is_simplified: false">(%arg0, %arg1)
"domain: d0 in [0, 5], d1 in [0, 4]">(%arg0, %arg1)
%1 = xla_gpu.apply_indexing #xla_gpu.indexing_map<"(d0, d1) -> (d0 + d1),"
"domain: d0 in [0, 4], d1 in [0, 10000], is_simplified: false">(%arg1, %0)
"domain: d0 in [0, 4], d1 in [0, 10000]">(%arg1, %0)
func.return %1 : index
}

Expand Down Expand Up @@ -234,7 +221,7 @@ func.func @atomic_rmw_cst(%in: tensor<2x3xf32>, %i: index, %j: index)
// -----

#map0 = #xla_gpu.indexing_map<"(d0)[s0] -> (2 * d0 * s0),"
"domain: d0 in [0, 3], s0 in [0, 2], is_simplified: false">
"domain: d0 in [0, 3], s0 in [0, 2]">
func.func @apply_indexing_move_syms_to_dims(%dim0: index, %sym0: index)
-> index {
%0 = xla_gpu.apply_indexing #map0(%dim0)[%sym0]
Expand All @@ -249,10 +236,9 @@ func.func @apply_indexing_move_syms_to_dims(%dim0: index, %sym0: index)

// // -----

#map0 = #xla_gpu.indexing_map<"(d0) -> (4 * d0), domain: d0 in [0, 3],"
"is_simplified: false">
#map0 = #xla_gpu.indexing_map<"(d0) -> (4 * d0), domain: d0 in [0, 3]">
#map1 = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (d0 + s0, s1),"
"domain: d0 in [0, 12], s0 in [0, 1024], s1 in [0, 32], is_simplified: false">
"domain: d0 in [0, 12], s0 in [0, 1024], s1 in [0, 32]">
func.func @loop_of_apply_indexing(%input: tensor<1024x32xf32>, %init: f32, %dim: index) -> (f32) {
%idx = xla_gpu.apply_indexing #map0(%dim)
%sum = xla_gpu.loop (%idx)[%i, %j] -> (%r0, %r1) in #map1 iter_args(%sum_ = %init) -> (f32) {
Expand All @@ -273,9 +259,9 @@ func.func @loop_of_apply_indexing(%input: tensor<1024x32xf32>, %init: f32, %dim:
// -----

#map0 = #xla_gpu.indexing_map<"(d0)[s0] -> (2 * d0 * s0),"
"domain: d0 in [0, 3], s0 in [0, 2], is_simplified: false">
"domain: d0 in [0, 3], s0 in [0, 2]">
#map1 = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (d0 + s0 + s1),"
"domain: d0 in [0, 12], s0 in [0, 1024], s1 in [0, 32], is_simplified: false">
"domain: d0 in [0, 12], s0 in [0, 1024], s1 in [0, 32]">
func.func @loop_of_apply_indexing_with_syms(%dim0: index, %sym0: index, %input: tensor<1024x32xf32>, %init: f32) -> (f32) {
%0 = xla_gpu.apply_indexing #map0(%dim0)[%sym0]
%sum = xla_gpu.loop (%0)[%i, %j] -> (%r0) in #map1 iter_args(%sum_ = %init) -> (f32) {
Expand Down
Loading

0 comments on commit 0983168

Please sign in to comment.