From 541453589617e180d05fad4973884d1941214f96 Mon Sep 17 00:00:00 2001 From: Alexander Belyaev Date: Fri, 27 Sep 2024 06:58:39 -0700 Subject: [PATCH] [XLA:GPU][IndexAnalysis] Remove is_simplified flag. The benchmarks don't show a lot of improvements in compile time. PiperOrigin-RevId: 679573820 --- xla/service/gpu/fusions/ir/tests/attrs.mlir | 30 +- .../gpu/fusions/ir/tests/canonicalize.mlir | 60 +-- xla/service/gpu/fusions/ir/tests/invalid.mlir | 74 +-- xla/service/gpu/fusions/ir/tests/ops.mlir | 20 +- xla/service/gpu/fusions/ir/xla_gpu_attrs.cc | 14 +- xla/service/gpu/fusions/ir/xla_gpu_attrs.td | 5 +- xla/service/gpu/fusions/ir/xla_gpu_ops.cc | 15 +- .../gpu/fusions/legacy/concatenate_test.cc | 3 +- .../in_place_dynamic_update_slice_test.cc | 3 +- .../gpu/fusions/legacy/input_slices_test.cc | 3 +- xla/service/gpu/fusions/legacy/loop_test.cc | 15 +- .../gpu/fusions/legacy/reduction_test.cc | 6 +- .../gpu/fusions/legacy/scatter_test.cc | 6 +- .../gpu/fusions/legacy/transpose_test.cc | 24 +- .../mlir/elemental_hlo_to_mlir_test.cc | 60 +-- .../transforms/tests/flatten_tensors.mlir | 10 +- .../fusions/transforms/tests/fuse_loops.mlir | 36 +- .../tests/lower_xla_gpu_loops_to_scf.mlir | 6 +- .../tests/lower_xla_gpu_to_scf.mlir | 22 +- .../transforms/tests/optimize_loops.mlir | 7 +- .../fusions/transforms/tests/peel_loops.mlir | 14 +- .../transforms/tests/simplify_affine.mlir | 8 +- .../transforms/tests/simplify_arith.mlir | 9 +- .../tests/vectorize_loads_stores.mlir | 37 +- .../triton_fusion_emitter_device_test.cc | 14 +- xla/service/gpu/model/indexing_analysis.cc | 8 +- .../gpu/model/indexing_analysis_test.cc | 489 ++++++------------ xla/service/gpu/model/indexing_map.cc | 23 +- xla/service/gpu/model/indexing_map.h | 12 +- .../gpu/model/indexing_map_serialization.cc | 53 +- .../model/indexing_map_serialization_test.cc | 21 +- xla/service/gpu/model/indexing_map_test.cc | 404 +++++---------- .../gpu/model/symbolic_tile_analysis_test.cc | 45 +- 33 files changed, 570 insertions(+), 986 deletions(-) diff --git a/xla/service/gpu/fusions/ir/tests/attrs.mlir b/xla/service/gpu/fusions/ir/tests/attrs.mlir index b990103ea2cfab..6a199f5f024241 100644 --- a/xla/service/gpu/fusions/ir/tests/attrs.mlir +++ b/xla/service/gpu/fusions/ir/tests/attrs.mlir @@ -8,8 +8,7 @@ // CHECK-SAME: d2 in [10, 12], // CHECK-SAME: s0 in [0, 32], // CHECK-SAME: d0 + s0 in [1, 10], -// CHECK-SAME: d0 mod 2 in [0, 1], -// CHECK-SAME: is_simplified: true" +// CHECK-SAME: d0 mod 2 in [0, 1] // CHECK-SAME: > #map = #xla_gpu.indexing_map<"(d0, d1, d2)[s0] -> (d0)," "domain:" @@ -18,8 +17,7 @@ "d2 in [10, 12]," "s0 in [0, 32]," "d0 mod 2 in [0, 1]," - "d0 + s0 in [1, 10]," - "is_simplified: true" + "d0 + s0 in [1, 10]" > func.func private @indexing_map_attr(!xla_gpu.indexed_vector<64x64x32xf64, #map>) @@ -39,7 +37,6 @@ func.func private @indexing_map_attr(!xla_gpu.indexed_vector<64x64x32xf64, #map> // CHECK-SAME: d0 + s0 in [1, 10] // CHECK-SAME: d0 mod 2 in [0, 1] // CHECK-SAME: d1 + s1 + s2 in [1, 32] -// CHECK-SAME: is_simplified: false" // CHECK-SAME: > #map = #xla_gpu.indexing_map< "(d0, d1)[s0, s1, s2] -> (d0 + s0, d1 + s1, d1 + s2)," @@ -51,8 +48,7 @@ func.func private @indexing_map_attr(!xla_gpu.indexed_vector<64x64x32xf64, #map> "s2 in [0, 32]," "d0 mod 2 in [0, 1]," "d0 + s0 in [1, 10]," - "d1 + s1 + s2 in [1, 32]," - "is_simplified: false" + "d1 + s1 + s2 in [1, 32]" > func.func private @more_range_vars(!xla_gpu.indexed_vector<100x32xf64, #map>) // CHECK-LABEL: @more_range_vars @@ -65,13 +61,11 @@ func.func private 
@more_range_vars(!xla_gpu.indexed_vector<100x32xf64, #map>) // CHECK-SAME: domain: // CHECK-SAME: d0 in [0, 100] // CHECK-SAME: s0 in [-3, -1] -// CHECK-SAME: is_simplified: false" // CHECK-SAME: > #map = #xla_gpu.indexing_map<"(d0)[s0] -> (d0)," "domain:" "d0 in [0, 100]," - "s0 in [-3, -1]," - "is_simplified: false" + "s0 in [-3, -1]" > func.func private @indexing_map_small(!xla_gpu.indexed_vector<100xf64, #map>) // CHECK-LABEL: @indexing_map_small @@ -86,15 +80,13 @@ func.func private @indexing_map_small(!xla_gpu.indexed_vector<100xf64, #map>) // CHECK-SAME: d1 in [5, 8] // CHECK-SAME: d2 in [10, 12] // CHECK-SAME: s0 in [0, 32] -// CHECK-SAME: is_simplified: false" // CHECK-SAME: > #map = #xla_gpu.indexing_map<"(d0, d1, d2)[s0] -> (d0)," "domain:" "d0 in [1, 2]," "d1 in [5, 8]," "d2 in [10, 12]," - "s0 in [0, 32]," - "is_simplified: false" + "s0 in [0, 32]" > func.func private @no_constraints(!xla_gpu.indexed_vector<32xf64, #map>) // CHECK-LABEL: @no_constraints @@ -107,13 +99,11 @@ func.func private @no_constraints(!xla_gpu.indexed_vector<32xf64, #map>) // CHECK-SAME: domain: // CHECK-SAME: s0 in [3, 5] // CHECK-SAME: s0 mod 2 in [0, 1] -// CHECK-SAME: is_simplified: false" // CHECK-SAME: > #map = #xla_gpu.indexing_map<"()[s0] -> (s0)," "domain:" "s0 in [3, 5]," - "s0 mod 2 in [0, 1]," - "is_simplified: false" + "s0 mod 2 in [0, 1]" > func.func private @no_dimensions(!xla_gpu.indexed_vector<100xf64, #map>) // CHECK-LABEL: @no_dimensions @@ -126,13 +116,11 @@ func.func private @no_dimensions(!xla_gpu.indexed_vector<100xf64, #map>) // CHECK-SAME: domain: // CHECK-SAME: d0 in [3, 5] // CHECK-SAME: d0 mod 2 in [0, 1] -// CHECK-SAME: is_simplified: false" // CHECK-SAME: > #map = #xla_gpu.indexing_map<"(d0) -> (d0)," "domain:" "d0 in [3, 5]," "d0 mod 2 in [0, 1]," - "is_simplified: false" > func.func private @no_symbols(!xla_gpu.indexed_vector<100xf64, #map>) // CHECK-LABEL: @no_symbols @@ -152,8 +140,6 @@ func.func private @empty(!xla_gpu.indexed_vector<100xf64, #map>) func.func private @tensor_layout( %in0: tensor<42xf32, #xla_gpu.layout<"shmem", - "(d0) -> ()," - "domain: d0 in [0, 42], is_simplified: true">>) -// CHECK: #layout = #xla_gpu.layout<"shmem", "(d0) -> (), -// CHECK-SAME: domain: d0 in [0, 42], is_simplified: true"> + "(d0) -> ()," "domain: d0 in [0, 42]">>) +// CHECK: #layout = #xla_gpu.layout<"shmem", "(d0) -> (), domain: // CHECK: tensor<42xf32, #layout> diff --git a/xla/service/gpu/fusions/ir/tests/canonicalize.mlir b/xla/service/gpu/fusions/ir/tests/canonicalize.mlir index bfca90e5c64f53..08086e34f60b05 100644 --- a/xla/service/gpu/fusions/ir/tests/canonicalize.mlir +++ b/xla/service/gpu/fusions/ir/tests/canonicalize.mlir @@ -1,13 +1,12 @@ // RUN: mlir_fusions_opt %s --split-input-file -canonicalize | FileCheck %s -#map0 = #xla_gpu.indexing_map<"()[s0, s1] -> (1 + s0 + s1 mod 3 - s1, s0 mod 2), domain: s0 in [-10, 10], s1 in [0, 2], is_simplified: false"> +#map0 = #xla_gpu.indexing_map<"()[s0, s1] -> (1 + s0 + s1 mod 3 - s1, s0 mod 2), domain: s0 in [-10, 10], s1 in [0, 2]"> func.func @simplify_apply_indexing(%s0: index, %s1: index) -> (index, index) { %0:2 = xla_gpu.apply_indexing #map0 [%s0, %s1] func.return %0#0, %0#1 : index, index } // CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<"(d0) -> (d0 + 1, d0 mod 2), -// CHECK-SAME: domain: d0 in [-10, 10] -// CHECK-SAME: is_simplified: true"> +// CHECK-SAME: domain: d0 in [-10, 10]"> // CHECK-LABEL: func.func @simplify_apply_indexing // CHECK-SAME: %[[ARG_0:.*]]: index, %[[ARG_1:.*]]: index) @@ -15,7 +14,7 @@ func.func 
@simplify_apply_indexing(%s0: index, %s1: index) -> (index, index) { // ----- -#map0 = #xla_gpu.indexing_map<"(d0, d1, d2)[s0, s1] -> (1 + s0 + s1 mod 4 - s1, s0 mod 2, d0 + d2), domain: d0 in [0, 1], d1 in [0, 2], d2 in [0, 3], s0 in [-11, 11], s1 in [0, 3], is_simplified: false"> +#map0 = #xla_gpu.indexing_map<"(d0, d1, d2)[s0, s1] -> (1 + s0 + s1 mod 4 - s1, s0 mod 2, d0 + d2), domain: d0 in [0, 1], d1 in [0, 2], d2 in [0, 3], s0 in [-11, 11], s1 in [0, 3]"> func.func @simplify_apply_indexing_remove_dims(%d0: index, %d1: index, %d2: index, %s0: index, %s1: index) -> (index, index, index) { %0:3 = xla_gpu.apply_indexing #map0(%d0, %d1, %d2)[%s0, %s1] @@ -35,16 +34,7 @@ func.func @simplify_apply_indexing_remove_dims(%d0: index, %d1: index, // ----- -#map0 = #xla_gpu.indexing_map<"(d0) -> (d0 mod 10), domain: d0 in [0, 9], is_simplified: true"> -func.func @do_not_simplify_if_is_simplified_is_true(%d0: index) -> (index) { - %0 = xla_gpu.apply_indexing #map0(%d0) - func.return %0 : index -} -// CHECK: #xla_gpu.indexing_map<"(d0) -> (d0 mod 10) - -// ----- - -#map0 = #xla_gpu.indexing_map<"(d0, d1)[s0] -> (d0 + s0, 4, d1, 1, s0), domain: d0 in [-10, 10], d1 in [0, 2], s0 in [-1, 1], is_simplified: false"> +#map0 = #xla_gpu.indexing_map<"(d0, d1)[s0] -> (d0 + s0, 4, d1, 1, s0), domain: d0 in [-10, 10], d1 in [0, 2], s0 in [-1, 1]"> func.func @fold_indexing_map_results(%d0: index, %d1: index, %s0: index) -> (index, index, index, index, index) { %0:5 = xla_gpu.apply_indexing #map0 (%d0, %d1)[%s0] @@ -64,7 +54,7 @@ func.func @fold_indexing_map_results(%d0: index, %d1: index, %s0: index) // ----- #map0 = #xla_gpu.indexing_map<"(d0, d1)[s0] -> (d0 + s0, s0 + 4, d1 mod 2, 1 + d1, s0)," - "domain: d0 in [-10, 10], d1 in [0, 2], s0 in [-1, 1], is_simplified: false"> + "domain: d0 in [-10, 10], d1 in [0, 2], s0 in [-1, 1]"> func.func @remove_unused_results(%d0: index, %d1: index, %s0: index) -> (index) { %0:5 = xla_gpu.apply_indexing #map0 (%d0, %d1)[%s0] func.return %0#2 : index @@ -81,8 +71,7 @@ func.func @remove_unused_results(%d0: index, %d1: index, %s0: index) -> (index) // ----- #map0 = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] -> (d0 + d1 + s0 + s1 mod 3)," - "domain: d0 in [0, 10], d1 in [0, 5], s0 in [-10, 10], s1 in [0, 4]," - "is_simplified: false"> + "domain: d0 in [0, 10], d1 in [0, 5], s0 in [-10, 10], s1 in [0, 4]"> func.func @fold_operands(%d0: index) -> index { %d1 = arith.constant 1 : index %s0 = arith.constant 2 : index @@ -102,7 +91,7 @@ func.func @fold_operands(%d0: index) -> index { func.func @fold_operands_and_results(%arg0: index, %arg1: index) -> (index, index) { %0:2 = xla_gpu.apply_indexing #xla_gpu.indexing_map<"(d0, d1) -> (0, d1)," - "domain: d0 in [0, 4], d1 in [0, 5], is_simplified: false">(%arg0, %arg1) + "domain: d0 in [0, 4], d1 in [0, 5]">(%arg0, %arg1) return %0#0, %0#1 : index, index } @@ -115,10 +104,9 @@ func.func @fold_operands_and_results(%arg0: index, %arg1: index) func.func @fold_sequence(%arg0: index, %arg1: index) -> index { %0 = xla_gpu.apply_indexing #xla_gpu.indexing_map< - "(d0, d1) -> (d0 + d1), domain: d0 in [0, 5], d1 in [0, 4]," - "is_simplified: false">(%arg0, %arg1) + "(d0, d1) -> (d0 + d1), domain: d0 in [0, 5], d1 in [0, 4]">(%arg0, %arg1) %1 = xla_gpu.apply_indexing #xla_gpu.indexing_map<"(d0) -> (d0 mod 100 + 42)," - "domain: d0 in [0, 10000], is_simplified: false">(%0) + "domain: d0 in [0, 10000]">(%0) func.return %1 : index } @@ -133,10 +121,9 @@ func.func @fold_sequence(%arg0: index, %arg1: index) -> index { func.func 
@fold_sequence_sym(%arg0: index, %arg1: index) -> index { %0 = xla_gpu.apply_indexing #xla_gpu.indexing_map<"(d0, d1) -> (d0 + d1), " - "domain: d0 in [0, 5], d1 in [0, 4], is_simplified: false">(%arg0, %arg1) + "domain: d0 in [0, 5], d1 in [0, 4]">(%arg0, %arg1) %1 = xla_gpu.apply_indexing #xla_gpu.indexing_map< - "()[s0] -> (s0 mod 100 + 42), domain: s0 in [0, 10000]," - "is_simplified: false">(%0) + "()[s0] -> (s0 mod 100 + 42), domain: s0 in [0, 10000]">(%0) func.return %1 : index } @@ -150,10 +137,10 @@ func.func @fold_sequence_sym(%arg0: index, %arg1: index) -> index { // ----- #indexing_map1 = #xla_gpu.indexing_map<"(d0, d1) -> (d1 * 2 + d0 + 8512)," - "domain: d0 in [0, 1], d1 in [0, 607], is_simplified: false"> + "domain: d0 in [0, 1], d1 in [0, 607]"> #indexing_map2 = #xla_gpu.indexing_map<"(d0, d1, d2) -> (" "((d1 floordiv 32 + 1) mod 3) * 64 + (d1 mod 32) * 2 + (d0 floordiv 192) * 192 + d2)," - "domain: d0 in [0, 9407], d1 in [0, 607], d2 in [0, 1], is_simplified: false"> + "domain: d0 in [0, 9407], d1 in [0, 607], d2 in [0, 1]"> func.func @fold_sequence_no_simplification_needed(%i: index) -> index { %thread_id_x = gpu.thread_id x {xla.range = [0 : index, 607 : index]} @@ -167,11 +154,11 @@ func.func @fold_sequence_no_simplification_needed(%i: index) -> index { // ----- #indexing_map1 = #xla_gpu.indexing_map< - "(d0) -> (3 * d0), domain: d0 in [0, 9407], is_simplified: false"> + "(d0) -> (3 * d0), domain: d0 in [0, 9407]"> #indexing_map2 = #xla_gpu.indexing_map<"(d0, d1, d2) -> (d0 floordiv 32 + 1)," - "domain: d0 in [0, 9407], d1 in [0, 607], d2 in [0, 1], is_simplified: false"> + "domain: d0 in [0, 9407], d1 in [0, 607], d2 in [0, 1]"> #indexing_map3 = #xla_gpu.indexing_map<"(d0, d1, d2) -> (d0 floordiv 32 + 2)," - "domain: d0 in [0, 9407], d1 in [0, 607], d2 in [0, 1], is_simplified: false"> + "domain: d0 in [0, 9407], d1 in [0, 607], d2 in [0, 1]"> func.func @no_fold_when_producer_has_two_users(%i: index) -> (index, index) { %thread_id_x = gpu.thread_id x {xla.range = [0 : index, 607 : index]} @@ -186,9 +173,9 @@ func.func @no_fold_when_producer_has_two_users(%i: index) -> (index, index) { func.func @fold_sequence_shared_operands(%arg0: index, %arg1: index) -> index { %0 = xla_gpu.apply_indexing #xla_gpu.indexing_map<"(d0, d1) -> (d0 + d1)," - "domain: d0 in [0, 5], d1 in [0, 4], is_simplified: false">(%arg0, %arg1) + "domain: d0 in [0, 5], d1 in [0, 4]">(%arg0, %arg1) %1 = xla_gpu.apply_indexing #xla_gpu.indexing_map<"(d0, d1) -> (d0 + d1)," - "domain: d0 in [0, 4], d1 in [0, 10000], is_simplified: false">(%arg1, %0) + "domain: d0 in [0, 4], d1 in [0, 10000]">(%arg1, %0) func.return %1 : index } @@ -234,7 +221,7 @@ func.func @atomic_rmw_cst(%in: tensor<2x3xf32>, %i: index, %j: index) // ----- #map0 = #xla_gpu.indexing_map<"(d0)[s0] -> (2 * d0 * s0)," - "domain: d0 in [0, 3], s0 in [0, 2], is_simplified: false"> + "domain: d0 in [0, 3], s0 in [0, 2]"> func.func @apply_indexing_move_syms_to_dims(%dim0: index, %sym0: index) -> index { %0 = xla_gpu.apply_indexing #map0(%dim0)[%sym0] @@ -249,10 +236,9 @@ func.func @apply_indexing_move_syms_to_dims(%dim0: index, %sym0: index) // // ----- -#map0 = #xla_gpu.indexing_map<"(d0) -> (4 * d0), domain: d0 in [0, 3]," - "is_simplified: false"> +#map0 = #xla_gpu.indexing_map<"(d0) -> (4 * d0), domain: d0 in [0, 3]"> #map1 = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (d0 + s0, s1)," - "domain: d0 in [0, 12], s0 in [0, 1024], s1 in [0, 32], is_simplified: false"> + "domain: d0 in [0, 12], s0 in [0, 1024], s1 in [0, 32]"> func.func 
@loop_of_apply_indexing(%input: tensor<1024x32xf32>, %init: f32, %dim: index) -> (f32) { %idx = xla_gpu.apply_indexing #map0(%dim) %sum = xla_gpu.loop (%idx)[%i, %j] -> (%r0, %r1) in #map1 iter_args(%sum_ = %init) -> (f32) { @@ -273,9 +259,9 @@ func.func @loop_of_apply_indexing(%input: tensor<1024x32xf32>, %init: f32, %dim: // ----- #map0 = #xla_gpu.indexing_map<"(d0)[s0] -> (2 * d0 * s0)," - "domain: d0 in [0, 3], s0 in [0, 2], is_simplified: false"> + "domain: d0 in [0, 3], s0 in [0, 2]"> #map1 = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (d0 + s0 + s1)," - "domain: d0 in [0, 12], s0 in [0, 1024], s1 in [0, 32], is_simplified: false"> + "domain: d0 in [0, 12], s0 in [0, 1024], s1 in [0, 32]"> func.func @loop_of_apply_indexing_with_syms(%dim0: index, %sym0: index, %input: tensor<1024x32xf32>, %init: f32) -> (f32) { %0 = xla_gpu.apply_indexing #map0(%dim0)[%sym0] %sum = xla_gpu.loop (%0)[%i, %j] -> (%r0) in #map1 iter_args(%sum_ = %init) -> (f32) { diff --git a/xla/service/gpu/fusions/ir/tests/invalid.mlir b/xla/service/gpu/fusions/ir/tests/invalid.mlir index 3c50b5afcd8068..35064858b23150 100644 --- a/xla/service/gpu/fusions/ir/tests/invalid.mlir +++ b/xla/service/gpu/fusions/ir/tests/invalid.mlir @@ -1,6 +1,6 @@ // RUN: mlir_fusions_opt %s -split-input-file -verify-diagnostics -#map0 = #xla_gpu.indexing_map<"(d0, d1)[s0] -> (d0, d1 + s0), domain: d0 in [1, 2], d1 in [5, 8], s0 in [0, 32], is_simplified: false"> +#map0 = #xla_gpu.indexing_map<"(d0, d1)[s0] -> (d0, d1 + s0), domain: d0 in [1, 2], d1 in [5, 8], s0 in [0, 32]"> func.func @apply_indexing(%d0: index, %d1: index, %s0: index) -> (index, index) { // expected-error @+1 {{operand count must match the number of dimensions and symbols in the affine map}} %0:2 = xla_gpu.apply_indexing #map0 (%d0) @@ -9,7 +9,7 @@ func.func @apply_indexing(%d0: index, %d1: index, %s0: index) -> (index, index) // ----- -#map0 = #xla_gpu.indexing_map<"(d0, d1)[s0] -> (d0, d1 + s0), domain: d0 in [1, 2], d1 in [5, 8], s0 in [0, 32], d0 mod 2 in [0, 1], d0 + s0 in [1, 10], is_simplified: false"> +#map0 = #xla_gpu.indexing_map<"(d0, d1)[s0] -> (d0, d1 + s0), domain: d0 in [1, 2], d1 in [5, 8], s0 in [0, 32], d0 mod 2 in [0, 1], d0 + s0 in [1, 10]"> func.func @cannot_have_constraints(%d0: index, %d1: index, %s0: index) -> (index, index) { // expected-error @+1 {{apply indexing op cannot have any constraints}} %0:2 = xla_gpu.apply_indexing #map0 (%d0, %d1)[%s0] @@ -18,7 +18,7 @@ func.func @cannot_have_constraints(%d0: index, %d1: index, %s0: index) -> (index // ----- -#map = #xla_gpu.indexing_map<"()[s0, s1] -> (s0, s1), domain: s0 in [0, 1024], s1 in [0, 32], is_simplified: false"> +#map = #xla_gpu.indexing_map<"()[s0, s1] -> (s0, s1), domain: s0 in [0, 1024], s1 in [0, 32]"> func.func @loop_result_num_mismatch(%input: tensor<1024x32xf32>, %init: f32) -> (f32) { // expected-error @+1 {{mismatch in number of loop-carried values and results}} @@ -36,7 +36,7 @@ func.func @loop_result_num_mismatch(%input: tensor<1024x32xf32>, // ----- -#map = #xla_gpu.indexing_map<"()[s0] -> (s0, s0), domain: s0 in [0, 1024], is_simplified: false"> +#map = #xla_gpu.indexing_map<"()[s0] -> (s0, s0), domain: s0 in [0, 1024]"> func.func @loop_iv_num_mismatch(%input: tensor<1024x32xf32>, %init: f32) -> (f32) { // expected-error @+1 {{mismatch in number of induction variables 2 and RangeVars}} @@ -54,7 +54,7 @@ func.func @loop_iv_num_mismatch(%input: tensor<1024x32xf32>, // ----- -#map = #xla_gpu.indexing_map<"()[s0, s1] -> (s0, s1), domain: s0 in [0, 1024], s1 in [0, 32], is_simplified: 
false"> +#map = #xla_gpu.indexing_map<"()[s0, s1] -> (s0, s1), domain: s0 in [0, 1024], s1 in [0, 32]"> func.func @loop_types_mismatch(%input: tensor<1024x32xf32>, %init: f32) -> (i32) { // expected-error @+1 {{block iter arg type = 'f32', result type = 'i32' and init operand type = 'f32' should match}} @@ -72,7 +72,7 @@ func.func @loop_types_mismatch(%input: tensor<1024x32xf32>, %init: f32) -> (i32) // ----- -#map = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (s0, s1), domain: d0 in [0, 3], s0 in [0, 1024], s1 in [0, 32], is_simplified: false"> +#map = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (s0, s1), domain: d0 in [0, 3], s0 in [0, 1024], s1 in [0, 32]"> func.func @loop_op(%input: tensor<1024x32xf32>, %init: f32, %dim: index) -> (f32) { // expected-error @+1 {{mismatch in number of dims operands 0 and DimVars in the indexing map}} @@ -87,7 +87,7 @@ func.func @loop_op(%input: tensor<1024x32xf32>, %init: f32, %dim: index) -> (f32 // ----- -#map = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], d1 in [0, 2], s0 in [0, 1024], s1 in [0, 32], is_simplified: false"> +#map = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], d1 in [0, 2], s0 in [0, 1024], s1 in [0, 32]"> func.func @indicies_mismatch(%input: tensor<32x64xf32>, %thread_id: index, %output: tensor<32x64xf32>) -> !xla_gpu.indexed_vector<32x64xf32, #map> { @@ -99,8 +99,8 @@ func.func @indicies_mismatch(%input: tensor<32x64xf32>, %thread_id: index, // ----- -#map = #xla_gpu.indexing_map<"()[s0, s1] -> (s0, s1), domain: s0 in [0, 1024], s1 in [0, 32], is_simplified: false"> -#map1 = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 32], is_simplified: false"> +#map = #xla_gpu.indexing_map<"()[s0, s1] -> (s0, s1), domain: s0 in [0, 1024], s1 in [0, 32]"> +#map1 = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 32]"> func.func @no_thread_id_in(%input: tensor<32x64xf32>, %output: tensor<32x64xf32>) -> !xla_gpu.indexed_vector<32x64xf32, #map1> { @@ -112,8 +112,8 @@ func.func @no_thread_id_in(%input: tensor<32x64xf32>, // ----- -#map = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 32], is_simplified: false"> -#map1 = #xla_gpu.indexing_map<"()[s0, s1] -> (s0, s1), domain: s0 in [0, 1024], s1 in [0, 32], is_simplified: false"> +#map = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 32]"> +#map1 = #xla_gpu.indexing_map<"()[s0, s1] -> (s0, s1), domain: s0 in [0, 1024], s1 in [0, 32]"> func.func @no_thread_id_out(%input: tensor<32x64xf32>, %thread_id: index, %output: tensor<32x64xf32>) -> !xla_gpu.indexed_vector<32x64xf32, #map1> { @@ -125,8 +125,8 @@ func.func @no_thread_id_out(%input: tensor<32x64xf32>, %thread_id: index, // ----- -#map = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 32], is_simplified: false"> -#map1 = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 64], s0 in [0, 1024], s1 in [0, 32], is_simplified: false"> +#map = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 32]"> +#map1 = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 64], s0 in [0, 1024], s1 in [0, 32]"> func.func @thread_id_bounds_mismatch(%input: tensor<32x64xf32>, %thread_id: index, %output: tensor<32x64xf32>) -> 
!xla_gpu.indexed_vector<32x64xf32, #map1> { // expected-error @+1 {{thread_id dimension must have the same bounds in both indexing maps}} %0 = xla_gpu.materialize @exp(%input) at #map(%thread_id) : (tensor<32x64xf32>) -> !xla_gpu.indexed_vector<32x64xf32, #map1> @@ -135,8 +135,8 @@ func.func @thread_id_bounds_mismatch(%input: tensor<32x64xf32>, %thread_id: inde // ----- -#map = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 32], d0 + s0 in [0, 1024], is_simplified: false"> -#map1 = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 32], is_simplified: false"> +#map = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 32], d0 + s0 in [0, 1024]"> +#map1 = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 32]"> func.func @thread_id_constraints_mismatch(%input: tensor<32x64xf32>, %thread_id: index, %output: tensor<32x64xf32>) @@ -149,8 +149,8 @@ func.func @thread_id_constraints_mismatch(%input: tensor<32x64xf32>, // ----- -#map = #xla_gpu.indexing_map<"(d0)[s0] -> (d0 + s0, s0), domain: d0 in [0, 32], s0 in [0, 1024], is_simplified: false"> -#map1 = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 32], is_simplified: false"> +#map = #xla_gpu.indexing_map<"(d0)[s0] -> (d0 + s0, s0), domain: d0 in [0, 32], s0 in [0, 1024]"> +#map1 = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 32]"> func.func @symbol_count_mismatch(%input: tensor<32x64xf32>, %thread_id: index, %output: tensor<32x64xf32>) -> !xla_gpu.indexed_vector<32x64xf32, #map1> { // expected-error @+1 {{number of symbols in both indexing_maps must match}} %0 = xla_gpu.materialize @exp(%input) at #map(%thread_id) : (tensor<32x64xf32>) -> !xla_gpu.indexed_vector<32x64xf32, #map1> @@ -159,8 +159,8 @@ func.func @symbol_count_mismatch(%input: tensor<32x64xf32>, %thread_id: index, % // ----- -#map = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 64], is_simplified: false"> -#map1 = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 32], is_simplified: false"> +#map = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 64]"> +#map1 = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 32]"> func.func @symbol_domain_mismatch(%input: tensor<32x64xf32>, %thread_id: index, %output: tensor<32x64xf32>) -> !xla_gpu.indexed_vector<32x64xf32, #map1> { // expected-error @+1 {{domain of symbols of indexing_maps must match}} %0 = xla_gpu.materialize @exp(%input) at #map(%thread_id) : (tensor<32x64xf32>) -> !xla_gpu.indexed_vector<32x64xf32, #map1> @@ -169,8 +169,8 @@ func.func @symbol_domain_mismatch(%input: tensor<32x64xf32>, %thread_id: index, // ----- -#map = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 64], s0 + s1 in [0, 1024], is_simplified: false"> -#map1 = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 64], s0 + s1 in [0, 32], is_simplified: false"> +#map = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 64], s0 + s1 in [0, 1024]"> +#map1 = 
#xla_gpu.indexing_map<"(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 64], s0 + s1 in [0, 32]"> func.func @symbol_constraints_mismatch(%input: tensor<32x64xf32>, %thread_id: index, %output: tensor<32x64xf32>) -> !xla_gpu.indexed_vector<32x64xf32, #map1> { @@ -182,8 +182,8 @@ func.func @symbol_constraints_mismatch(%input: tensor<32x64xf32>, // ----- -#map = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 64], s0 mod 2 in [0, 0], is_simplified: false"> -#map1 = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 64], s0 + s1 in [0, 32], is_simplified: false"> +#map = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 64], s0 mod 2 in [0, 0]"> +#map1 = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 64], s0 + s1 in [0, 32]"> func.func @symbol_constraint_mismatch(%input: tensor<32x64xf32>, %thread_id: index, %output: tensor<32x64xf32>) @@ -195,8 +195,8 @@ func.func @symbol_constraint_mismatch(%input: tensor<32x64xf32>, // ----- -#map = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 64], s0 + s1 in [0, 1024], is_simplified: false"> -#map1 = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 64], s0 + s1 in [0, 32], is_simplified: false"> +#map = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 64], s0 + s1 in [0, 1024]"> +#map1 = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 64], s0 + s1 in [0, 32]"> func.func @symbol_constraint_interval_mismatch(%input: tensor<32x64xf32>, %thread_id: index, %output: tensor<32x64xf32>) @@ -209,8 +209,8 @@ func.func @symbol_constraint_interval_mismatch(%input: tensor<32x64xf32>, // ----- -#map = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 64], is_simplified: false"> -#map1 = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] -> (d0 + s0, d1 + s1), domain: d0 in [0, 32], d1 in [0, 64], s0 in [0, 1024], s1 in [0, 64], is_simplified: false"> +#map = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 64]"> +#map1 = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] -> (d0 + s0, d1 + s1), domain: d0 in [0, 32], d1 in [0, 64], s0 in [0, 1024], s1 in [0, 64]"> func.func @vector_mapping_depends_on_block_id(%input: tensor<32x64xf32>, %thread_id: index, %output: tensor<32x64xf32>) -> !xla_gpu.indexed_vector<32x64xf32, #map1> { @@ -222,8 +222,8 @@ func.func @vector_mapping_depends_on_block_id(%input: tensor<32x64xf32>, // ----- -#map = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], d1 in [0, 64], s0 in [0, 1024], s1 in [0, 64], d1 mod 2 in [0, 0], is_simplified: false"> -#map1 = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 64], is_simplified: false"> +#map = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], d1 in [0, 64], s0 in [0, 1024], s1 in [0, 64], d1 mod 2 in [0, 0]"> +#map1 = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 64]"> func.func @block_id_constraints_mismatch(%input: tensor<32x64xf32>, %thread_id: index, %block_id: index, 
%output: tensor<32x64xf32>) @@ -236,8 +236,8 @@ func.func @block_id_constraints_mismatch(%input: tensor<32x64xf32>, // ----- -#map = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 64], is_simplified: false"> -#map1 = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], d1 in [0, 64], s0 in [0, 1024], s1 in [0, 64], d1 mod 2 in [0, 0], is_simplified: false"> +#map = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 64]"> +#map1 = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], d1 in [0, 64], s0 in [0, 1024], s1 in [0, 64], d1 mod 2 in [0, 0]"> func.func @block_id_constraints_mismatch(%input: tensor<32x64xf32>, %thread_id: index, %block_id: index, %output: tensor<32x64xf32>) @@ -250,8 +250,8 @@ func.func @block_id_constraints_mismatch(%input: tensor<32x64xf32>, // ----- -#map = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], d1 in [0, 64], s0 in [0, 1024], s1 in [0, 64], d1 mod 2 in [0, 0], is_simplified: false"> -#map1 = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], d1 in [0, 64], s0 in [0, 1024], s1 in [0, 64], d1 mod 4 in [0, 0], is_simplified: false"> +#map = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], d1 in [0, 64], s0 in [0, 1024], s1 in [0, 64], d1 mod 2 in [0, 0]"> +#map1 = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], d1 in [0, 64], s0 in [0, 1024], s1 in [0, 64], d1 mod 4 in [0, 0]"> func.func @block_id_constraints_mismatch(%input: tensor<32x64xf32>, %thread_id: index, %block_id: index, %output: tensor<32x64xf32>) @@ -264,8 +264,8 @@ func.func @block_id_constraints_mismatch(%input: tensor<32x64xf32>, // ----- -#map = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] -> (d1*32+d0*2+s0, s1), domain: d0 in [0, 32], d1 in [0, 8], s0 in [0, 1], s1 in [0, 1], is_simplified: false"> -#map1 = #xla_gpu.indexing_map<"(d0, d1)[s0] -> (d0 mod 16 + s0, d1), domain: d0 in [0, 32], d1 in [0, 2], s0 in [0, 1], is_simplified: false"> +#map = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] -> (d1*32+d0*2+s0, s1), domain: d0 in [0, 32], d1 in [0, 8], s0 in [0, 1], s1 in [0, 1]"> +#map1 = #xla_gpu.indexing_map<"(d0, d1)[s0] -> (d0 mod 16 + s0, d1), domain: d0 in [0, 32], d1 in [0, 2], s0 in [0, 1]"> func.func @insert(%input: !xla_gpu.indexed_vector<32x64xf32, #map>, %i: index, %j: index, %output: tensor<32x64xf32>) -> tensor<32x64xf32> { @@ -277,8 +277,8 @@ func.func @insert(%input: !xla_gpu.indexed_vector<32x64xf32, #map>, // ----- -#map = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] -> (d1*32+d0*2+s0, s1), domain: d0 in [0, 32], d1 in [0, 8], s0 in [0, 1], s1 in [0, 1], is_simplified: false"> -#map1 = #xla_gpu.indexing_map<"(d0, d1, d2) -> (d0 mod 16, d1, d2), domain: d0 in [0, 32], d1 in [0, 2], d2 in [0, 5], is_simplified: false"> +#map = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] -> (d1*32+d0*2+s0, s1), domain: d0 in [0, 32], d1 in [0, 8], s0 in [0, 1], s1 in [0, 1]"> +#map1 = #xla_gpu.indexing_map<"(d0, d1, d2) -> (d0 mod 16, d1, d2), domain: d0 in [0, 32], d1 in [0, 2], d2 in [0, 5]"> func.func @insert(%input: !xla_gpu.indexed_vector<32x64xf32, #map>, %i: index, %j: index, %output: tensor<32x64xf32>) -> tensor<32x64xf32> { diff --git a/xla/service/gpu/fusions/ir/tests/ops.mlir b/xla/service/gpu/fusions/ir/tests/ops.mlir index 81e08968db7590..f6fd03d8f1ed24 100644 --- a/xla/service/gpu/fusions/ir/tests/ops.mlir +++ 
b/xla/service/gpu/fusions/ir/tests/ops.mlir @@ -57,7 +57,7 @@ func.func @caller(%a: f32, %b: f32) -> f32 { // ----- #map0 = #xla_gpu.indexing_map<"(d0, d1)[s0] -> (d0, d1 + s0)," - "domain: d0 in [1, 2], d1 in [5, 8], s0 in [0, 32], is_simplified: false"> + "domain: d0 in [1, 2], d1 in [5, 8], s0 in [0, 32]"> func.func @apply_indexing(%d0: index, %d1: index, %s0: index) -> (index, index) { %0:2 = xla_gpu.apply_indexing #map0 (%d0, %d1)[%s0] func.return %0#0, %0#1 : index, index @@ -78,7 +78,7 @@ func.func @apply_indexing(%d0: index, %d1: index, %s0: index) -> (index, index) // ----- #map0 = #xla_gpu.indexing_map<"(d0, d1) -> (d0, d1)," - "domain: d0 in [0, 2], d1 in [1, 3], is_simplified: false"> + "domain: d0 in [0, 2], d1 in [1, 3]"> func.func @apply_indexing_no_symbols(%d0: index, %d1: index) -> (index, index) { %0:2 = xla_gpu.apply_indexing #map0 (%d0, %d1) func.return %0#0, %0#1 : index, index @@ -98,7 +98,7 @@ func.func @apply_indexing_no_symbols(%d0: index, %d1: index) -> (index, index) { // ----- #map0 = #xla_gpu.indexing_map<"()[s0] -> (s0, s0)," - "domain: s0 in [2, 4], is_simplified: false"> + "domain: s0 in [2, 4]"> func.func @apply_indexing_no_dims(%s0: index) -> (index, index) { %0:2 = xla_gpu.apply_indexing #map0 [%s0] func.return %0#0, %0#1 : index, index @@ -116,7 +116,7 @@ func.func @apply_indexing_no_dims(%s0: index) -> (index, index) { // ----- #map = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (s0, s1), " - "domain: d0 in [0, 3], s0 in [0, 1024], s1 in [0, 32], is_simplified: false"> + "domain: d0 in [0, 3], s0 in [0, 1024], s1 in [0, 32]"> func.func @loop_op(%input: tensor<1024x32xf32>, %init: f32, %dim: index) -> (f32) { %sum = xla_gpu.loop (%dim)[%i, %j] -> (%r0, %r1) @@ -141,11 +141,11 @@ func.func @loop_op(%input: tensor<1024x32xf32>, %init: f32, func.func private @exp(%p0: tensor<32x64xf32>, %i: index, %j: index) -> f32 #map = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] -> (d0 + s0, d1 + s1)," - "domain: d0 in [0, 32], d1 in [0, 2], s0 in [0, 1024], s1 in [0, 32], is_simplified: false"> + "domain: d0 in [0, 32], d1 in [0, 2], s0 in [0, 1024], s1 in [0, 32]"> #map1 = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] -> (s0, s1)," - "domain: d0 in [0, 32], d1 in [0, 2], s0 in [0, 1024], s1 in [0, 32], is_simplified: false"> + "domain: d0 in [0, 32], d1 in [0, 2], s0 in [0, 1024], s1 in [0, 32]"> #map2 = #xla_gpu.indexing_map<"(d0, d1) -> (d0, d1)," - "domain: d0 in [0, 32], d1 in [0, 2], is_simplified: false"> + "domain: d0 in [0, 32], d1 in [0, 2]"> func.func @materialize_and_insert(%input: tensor<32x64xf32>, %i: index, %j: index, %output: tensor<32x64xf32>) -> tensor<32x64xf32> { @@ -161,7 +161,7 @@ func.func @materialize_and_insert(%input: tensor<32x64xf32>, %i: index, // CHECK: #[[$MAP1:.*]] = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] -> (s0, s1) // CHECK-SAME: d0 in [0, 32], d1 in [0, 2], s0 in [0, 1024], s1 in [0, 32] // CHECK: #[[$MAP2:.*]] = #xla_gpu.indexing_map<"(d0, d1) -> (d0, d1) -// CHECK-SAME: d0 in [0, 32], d1 in [0, 2], +// CHECK-SAME: d0 in [0, 32], d1 in [0, 2]"> // CHECK-LABEL: @materialize_and_insert // CHECK: %[[MATERIALIZED:.*]] = xla_gpu.materialize @exp(%{{.*}}) at // CHECK-SAME: #[[$MAP]](%{{.*}}, %{{.*}}) @@ -216,7 +216,7 @@ func.func @reduce_middle_dim(%in: tensor<16x8x4xf32>, %init: f32) // ----- #map = #xla_gpu.indexing_map<"(d0, d1) -> (d0 * 64 + d1)," - "domain: d0 in [0, 15], d1 in [0, 63], is_simplified: false"> + "domain: d0 in [0, 15], d1 in [0, 63]"> func.func @reindex(%in0: tensor<1024xf32>) -> tensor<16x64xf32> { %0 = xla_gpu.reindex %in0 at #map 
: tensor<1024xf32> -> tensor<16x64xf32> func.return %0 : tensor<16x64xf32> @@ -231,7 +231,7 @@ func.func @reindex(%in0: tensor<1024xf32>) -> tensor<16x64xf32> { // ----- #map = #xla_gpu.indexing_map<"(d0, d1) -> (d0 * 64 + d1)," - "domain: d0 in [0, 15], d1 in [0, 63], is_simplified: false"> + "domain: d0 in [0, 15], d1 in [0, 63]"> func.func @reindex_pad(%in0: tensor<1022xf32>) -> tensor<16x64xf32> { %c0 = arith.constant 0.0 : f32 %0 = xla_gpu.reindex %in0 at #map default %c0 diff --git a/xla/service/gpu/fusions/ir/xla_gpu_attrs.cc b/xla/service/gpu/fusions/ir/xla_gpu_attrs.cc index 8a0380b0706f75..535a24fd55788a 100644 --- a/xla/service/gpu/fusions/ir/xla_gpu_attrs.cc +++ b/xla/service/gpu/fusions/ir/xla_gpu_attrs.cc @@ -48,8 +48,7 @@ using mlir::success; // Parses a chain of string attributes into an indexing map. // Example: // "()[s0, s1] -> (1 + s0 + s1 mod 3 - s1, s0 mod 2)," -// " domain: s0 in [-10, 10], s1 in [0, 2]," -// " is_simplified: false" +// " domain: s0 in [-10, 10], s1 in [0, 2]" // will be parsed as 3 StringAttrs, concatenated into a single string, and then // parsed into an IndexingMap. std::optional parseChainOfStringsAsIndexingMap( @@ -84,17 +83,16 @@ IndexingMapAttr IndexingMapAttr::get(mlir::MLIRContext* context, constraints.push_back({constraint.first, constraint.second}); } return get(context, indexing_map.GetAffineMap(), indexing_map.GetDimVars(), - indexing_map.GetRangeVars(), constraints, - indexing_map.IsSimplified()); + indexing_map.GetRangeVars(), constraints); } mlir::LogicalResult IndexingMapAttr::verify( mlir::function_ref emitError, mlir::AffineMap map, ArrayRef dim_vars, ArrayRef range_vars, - ArrayRef> constraints, bool is_simplified) { - auto indexing_map = IndexingMap(map, dim_vars, range_vars, /*rt_vars=*/{}, - constraints, is_simplified); + ArrayRef> constraints) { + auto indexing_map = + IndexingMap(map, dim_vars, range_vars, /*rt_vars=*/{}, constraints); std::stringstream ss; if (!indexing_map.Verify(ss)) { return emitError() << ss.str(); @@ -104,7 +102,7 @@ mlir::LogicalResult IndexingMapAttr::verify( IndexingMap IndexingMapAttr::getIndexingMap() const { return IndexingMap(getMap(), getDimVars(), getRangeVars(), /*rt_vars=*/{}, - getConstraints(), getIsSimplified()); + getConstraints()); } int64_t IndexingMapAttr::getNumResults() const { diff --git a/xla/service/gpu/fusions/ir/xla_gpu_attrs.td b/xla/service/gpu/fusions/ir/xla_gpu_attrs.td index 44e8dd4353a5b6..f42a2254558724 100644 --- a/xla/service/gpu/fusions/ir/xla_gpu_attrs.td +++ b/xla/service/gpu/fusions/ir/xla_gpu_attrs.td @@ -36,8 +36,6 @@ def XLAGPU_RangeVarsParameter : ArrayRefParameter<"::xla::gpu::RangeVar", "RangeVarArray"> { } -def XLAGPU_BoolParameter : AttrOrTypeParameter<"bool", ""> {} - def XLAGPU_ConstraintsParameter : ArrayRefParameter<"::std::pair<::mlir::AffineExpr, ::xla::gpu::Interval>", "ContraintsArray"> { @@ -52,8 +50,7 @@ def XLAGPU_IndexingMapAttr : XLAGPU_Attr<"IndexingMap"> { let parameters = (ins XLAGPU_AffineMapParameter:$map, XLAGPU_DimVarsParameter:$dim_vars, XLAGPU_RangeVarsParameter:$range_vars, - XLAGPU_ConstraintsParameter:$constraints, - XLAGPU_BoolParameter:$is_simplified); + XLAGPU_ConstraintsParameter:$constraints); let hasCustomAssemblyFormat = 1; let builders = [ AttrBuilder<(ins "const ::xla::gpu::IndexingMap&":$indexing_map)>, diff --git a/xla/service/gpu/fusions/ir/xla_gpu_ops.cc b/xla/service/gpu/fusions/ir/xla_gpu_ops.cc index a4724eb8b5c9f6..2aa00180e326b1 100644 --- a/xla/service/gpu/fusions/ir/xla_gpu_ops.cc +++ 
b/xla/service/gpu/fusions/ir/xla_gpu_ops.cc @@ -183,13 +183,13 @@ ParseResult ApplyIndexingOp::parse(OpAsmParser& parser, parser.parseOptionalAttrDict(result.attributes)) { return failure(); } - auto map = indexing_map_attr.getMap(); + auto map = indexing_map_attr.getIndexingMap().GetAffineMap(); result.addTypes(SmallVector(map.getNumResults(), index_type)); return success(); } void ApplyIndexingOp::print(OpAsmPrinter& p) { - AffineMap affine_map = getIndexingMapAttr().getMap(); + AffineMap affine_map = getIndexingMapAttr().getIndexingMap().GetAffineMap(); p << " " << getIndexingMapAttr(); auto operands = getOperands(); @@ -214,14 +214,14 @@ void ApplyIndexingOp::print(OpAsmPrinter& p) { } LogicalResult ApplyIndexingOp::verify() { - auto affine_map = getIndexingMapAttr().getMap(); + auto affine_map = getIndexingMapAttr().getIndexingMap().GetAffineMap(); unsigned num_variables = affine_map.getNumDims() + affine_map.getNumSymbols(); if (getOperands().size() != num_variables) { return emitOpError( "operand count must match the number of dimensions and symbols in the " "affine map"); } - if (!getIndexingMapAttr().getConstraints().empty()) { + if (!getIndexingMap().GetConstraints().empty()) { return emitOpError("apply indexing op cannot have any constraints"); } return success(); @@ -310,11 +310,10 @@ struct SimplifyIndexingMap : public mlir::OpRewritePattern { LogicalResult matchAndRewrite(ApplyIndexingOp indexing_op, PatternRewriter& rewriter) const override { IndexingMap indexing_map = indexing_op.getIndexingMap(); - if (indexing_map.IsSimplified()) { + if (!indexing_map.Simplify()) { return rewriter.notifyMatchFailure(indexing_op, "IndexingMap is already simplified"); } - indexing_map.Simplify(); rewriter.replaceOpWithNewOp( indexing_op, indexing_op.getOperands(), indexing_map); return success(); @@ -1046,12 +1045,12 @@ LogicalResult MaterializeOp::verify() { //===----------------------------------------------------------------------===// LogicalResult InsertOp::verify() { - if (!getMap().getRangeVars().empty()) { + if (!getMap().getIndexingMap().GetRangeVars().empty()) { return emitOpError() << "insert_op map must not have any symbols"; } int64_t vector_map_num_results = getSource().getType().getIndexingMapAttr().getNumResults(); - if (vector_map_num_results != getMap().getDimVars().size()) { + if (vector_map_num_results != getMap().getIndexingMap().GetDimVars().size()) { return emitOpError() << "source map result count must equal insert_op's " "map's dimension count"; } diff --git a/xla/service/gpu/fusions/legacy/concatenate_test.cc b/xla/service/gpu/fusions/legacy/concatenate_test.cc index 32437d5bca3772..ce7da7bcb22485 100644 --- a/xla/service/gpu/fusions/legacy/concatenate_test.cc +++ b/xla/service/gpu/fusions/legacy/concatenate_test.cc @@ -82,8 +82,7 @@ TEST_F(ConcatenateTest, ThreadIndexing) { bl_z in [0, 0], chunk_id in [0, 0], unroll_id in [0, 0], - bl_x * 128 + th_x in [0, 399], - is_simplified: true + bl_x * 128 + th_x in [0, 399] )"; EXPECT_THAT( ToString(*fusion->ComputeThreadIdToInputIndexing( diff --git a/xla/service/gpu/fusions/legacy/in_place_dynamic_update_slice_test.cc b/xla/service/gpu/fusions/legacy/in_place_dynamic_update_slice_test.cc index 6bf9ea865e1c45..27d3aa2170be3f 100644 --- a/xla/service/gpu/fusions/legacy/in_place_dynamic_update_slice_test.cc +++ b/xla/service/gpu/fusions/legacy/in_place_dynamic_update_slice_test.cc @@ -88,8 +88,7 @@ TEST_F(InPlaceDynamicUpdateSliceFusionTest, ThreadIndexing) { bl_y in [0, 0], bl_z in [0, 0], chunk_id in [0, 0], - unroll_id 
in [0, 0], - is_simplified: true + unroll_id in [0, 0] )")); auto thread_id_dst_indexing = fusion->ComputeThreadIdToInputIndexing( /*root_index=*/0, /*hero_operand_index=*/0, &mlir_context_); diff --git a/xla/service/gpu/fusions/legacy/input_slices_test.cc b/xla/service/gpu/fusions/legacy/input_slices_test.cc index 08fcc0d387c777..9de13b8bd7df5c 100644 --- a/xla/service/gpu/fusions/legacy/input_slices_test.cc +++ b/xla/service/gpu/fusions/legacy/input_slices_test.cc @@ -88,8 +88,7 @@ TEST_F(InputSlicesTest, ThreadIndexing) { bl_z in [0, 0], chunk_id in [0, 0], unroll_id in [0, 0], - bl_x * 128 + th_x in [0, 29], - is_simplified: true + bl_x * 128 + th_x in [0, 29] )")); } diff --git a/xla/service/gpu/fusions/legacy/loop_test.cc b/xla/service/gpu/fusions/legacy/loop_test.cc index 60ae18e5cc6a17..b23e9b2b19a213 100644 --- a/xla/service/gpu/fusions/legacy/loop_test.cc +++ b/xla/service/gpu/fusions/legacy/loop_test.cc @@ -93,8 +93,7 @@ TEST_F(LoopTest, ThreadIndexingUnrolled) { bl_z in [0, 0], chunk_id in [0, 0], unroll_id in [0, 3], - bl_x * 128 + th_x in [0, 1499999], - is_simplified: true + bl_x * 128 + th_x in [0, 1499999] )")); } @@ -133,8 +132,7 @@ TEST_F(LoopTest, ThreadIndexingNotUnrolled) { bl_y in [0, 0], bl_z in [0, 0], chunk_id in [0, 0], - unroll_id in [0, 0], - is_simplified: true + unroll_id in [0, 0] )")); auto thread_id_to_input_indexing = loop_fusion->ComputeThreadIdToInputIndexing( @@ -152,8 +150,7 @@ TEST_F(LoopTest, ThreadIndexingNotUnrolled) { bl_y in [0, 0], bl_z in [0, 0], chunk_id in [0, 0], - unroll_id in [0, 0], - is_simplified: true + unroll_id in [0, 0] )")); } @@ -196,8 +193,7 @@ TEST_F(LoopTest, Broadcast) { bl_z in [0, 0], chunk_id in [0, 0], unroll_id in [0, 0], - bl_x * 128 + th_x in [0, 5999], - is_simplified: true + bl_x * 128 + th_x in [0, 5999] )")); auto thread_id_to_input_indexing = loop_fusion->ComputeThreadIdToInputIndexing( @@ -217,8 +213,7 @@ TEST_F(LoopTest, Broadcast) { bl_z in [0, 0], chunk_id in [0, 0], unroll_id in [0, 0], - bl_x * 128 + th_x in [0, 5999], - is_simplified: true + bl_x * 128 + th_x in [0, 5999] )")); } diff --git a/xla/service/gpu/fusions/legacy/reduction_test.cc b/xla/service/gpu/fusions/legacy/reduction_test.cc index 46c7a26970e538..54fff94a6ed775 100644 --- a/xla/service/gpu/fusions/legacy/reduction_test.cc +++ b/xla/service/gpu/fusions/legacy/reduction_test.cc @@ -86,8 +86,7 @@ TEST_F(ReductionTest, ThreadIndexingRowReduction) { s0 in [0, 0], s1 in [0, 0], s2 in [0, 7], - s3 in [0, 1], - is_simplified: true + s3 in [0, 1] )")); EXPECT_THAT( ToString(*fusion.ComputeThreadIdToOutputIndexing(0, &mlir_context_)), @@ -103,8 +102,7 @@ TEST_F(ReductionTest, ThreadIndexingRowReduction) { d3 in [0, 799], d4 in [0, 0], d5 in [0, 0], - d0 mod 32 in [0, 0], - is_simplified: true + d0 mod 32 in [0, 0] )")); } diff --git a/xla/service/gpu/fusions/legacy/scatter_test.cc b/xla/service/gpu/fusions/legacy/scatter_test.cc index 7381d375645660..e7d1d8eae303c9 100644 --- a/xla/service/gpu/fusions/legacy/scatter_test.cc +++ b/xla/service/gpu/fusions/legacy/scatter_test.cc @@ -156,8 +156,7 @@ TEST_F(ScatterFusionTest, ThreadIdIndexing) { bl_z in [0, 0], chunk_id in [0, 0], unroll_id in [0, 0], - bl_x * 128 + th_x in [0, 8399], - is_simplified: true + bl_x * 128 + th_x in [0, 8399] )"; mlir::SmallVector dim_names = {"th_x", "th_y", "th_z", "bl_x", "bl_y", "bl_z"}; @@ -197,8 +196,7 @@ TEST_F(ScatterFusionTest, ThreadIdIndexing) { chunk_id in [0, 0], unroll_id in [0, 0], index_id in [0, 0], - bl_x * 128 + th_x in [0, 8399], - is_simplified: true + 
bl_x * 128 + th_x in [0, 8399] )"; EXPECT_THAT( ToString(*fusion->ComputeThreadIdToInputIndexing( diff --git a/xla/service/gpu/fusions/legacy/transpose_test.cc b/xla/service/gpu/fusions/legacy/transpose_test.cc index c66094061e6366..1e503025d889d3 100644 --- a/xla/service/gpu/fusions/legacy/transpose_test.cc +++ b/xla/service/gpu/fusions/legacy/transpose_test.cc @@ -95,8 +95,7 @@ TEST_F(TransposeTest, ThreadIndexing021) { s0 in [0, 0], s1 in [0, 7], - s2 in [0, 0], - is_simplified: true + s2 in [0, 0] )")); EXPECT_THAT( ToString(*fusion->ComputeThreadIdToOutputIndexing(0, &mlir_context)), @@ -116,8 +115,7 @@ TEST_F(TransposeTest, ThreadIndexing021) { s0 in [0, 0], s1 in [0, 7], - s2 in [0, 0], - is_simplified: true + s2 in [0, 0] )")); } @@ -159,8 +157,7 @@ TEST_F(TransposeTest, ThreadIndexing201_SimplifiedTo021) { s0 in [0, 0], s1 in [0, 7], - s2 in [0, 0], - is_simplified: true + s2 in [0, 0] )")); EXPECT_THAT( ToString(*fusion->ComputeThreadIdToOutputIndexing(0, &mlir_context)), @@ -180,8 +177,7 @@ TEST_F(TransposeTest, ThreadIndexing201_SimplifiedTo021) { s0 in [0, 0], s1 in [0, 7], - s2 in [0, 0], - is_simplified: true + s2 in [0, 0] )")); } @@ -225,8 +221,7 @@ TEST_F(TransposeTest, ThreadIndexingPartialBlock) { s0 in [0, 5], s1 in [0, 0], s2 in [0, 0], - d0 mod 32 in [0, 23], - is_simplified: true + d0 mod 32 in [0, 23] )")); EXPECT_THAT( ToString(*fusion->ComputeThreadIdToOutputIndexing(0, &mlir_context)), @@ -246,8 +241,7 @@ TEST_F(TransposeTest, ThreadIndexingPartialBlock) { s0 in [0, 5], s1 in [0, 0], s2 in [0, 0], - d0 mod 32 in [0, 23], - is_simplified: true + d0 mod 32 in [0, 23] )")); } @@ -322,8 +316,7 @@ TEST_F(TransposeTest, ThreadIndexingSideOutput) { s0 in [0, 0], s1 in [0, 7], - s2 in [0, 0], - is_simplified: true + s2 in [0, 0] )")); EXPECT_THAT( ToString(*fusion->ComputeThreadIdToOutputIndexing(1, &mlir_context)), @@ -343,8 +336,7 @@ TEST_F(TransposeTest, ThreadIndexingSideOutput) { s0 in [0, 0], s1 in [0, 7], - s2 in [0, 0], - is_simplified: true + s2 in [0, 0] )")); } diff --git a/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir_test.cc b/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir_test.cc index 5c87db0045dac0..522d2653153292 100644 --- a/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir_test.cc +++ b/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir_test.cc @@ -234,10 +234,10 @@ TEST_F(ElementalHloToMlirTest, ReduceWindow) { // CHECK: %[[INIT:.*]] = tensor.extract %[[ARG1]][] // CHECK: %[[RET:.*]] = scf.for %[[I:.*]] = %[[C0]] to %[[C7]] // CHECK-SAME: step %[[C1]] iter_args(%[[ACC:.*]] = %[[INIT]]) - // CHECK: %[[J0:.*]] = xla_gpu.apply_indexing #xla_gpu.indexing_map<"(d0) -> (d0 * 4), domain: d0 in [0, 2], is_simplified: true">(%[[Y]]) + // CHECK: %[[J0:.*]] = xla_gpu.apply_indexing #xla_gpu.indexing_map<"(d0) -> (d0 * 4), domain: d0 in [0, 2]">(%[[Y]]) // CHECK: %[[J1:.*]] = xla_gpu.apply_indexing // CHECK-SAME: #xla_gpu.indexing_map<"(d0, d1) -> (d0 + d1 - 3), - // CHECK-SAME: d0 in [0, 7], d1 in [0, 6], is_simplified: true">(%[[Z]], %[[I]]) + // CHECK-SAME: d0 in [0, 7], d1 in [0, 6]">(%[[Z]], %[[I]]) // CHECK: %[[VAL:.*]] = tensor.extract %[[ARG0]] // CHECK-SAME: [%[[X]], %[[J0]], %[[J1]]] // CHECK: %[[UPD:.*]] = func.call @add_sum(%[[ACC]], @@ -285,7 +285,7 @@ TEST_F(ElementalHloToMlirTest, ReduceWindowWithRescaling) { // `d1 floordiv ` in the map: // CHECK: %[[K:.*]] = xla_gpu.apply_indexing // CHECK-SAME: #xla_gpu.indexing_map<"(d0, d1) -> (d0 * 2 + d1), - // CHECK-SAME: d0 in [0, 18], d1 in [0, 3], is_simplified: true">(%[[X]], %[[I]]) + // 
CHECK-SAME: d0 in [0, 18], d1 in [0, 3]">(%[[X]], %[[I]]) // CHECK: tensor.extract %[[ARG0]][%[[K]], %[[Y]], %[[Z]]] )")); @@ -505,7 +505,7 @@ TEST_F(ElementalHloToMlirTest, Pad) { // CHECK-DAG: %[[C4:.*]] = arith.constant 4 // CHECK-DAG: %[[C7:.*]] = arith.constant 7 // CHECK: %[[CONSTRAINT_VAL:.*]] = xla_gpu.apply_indexing - // CHECK-SAME: <"(d0) -> ((d0 - 1) mod 2), domain: d0 in [1, 7], is_simplified: true">(%[[X]]) + // CHECK-SAME: <"(d0) -> ((d0 - 1) mod 2), domain: d0 in [1, 7]">(%[[X]]) // CHECK: %[[CONSTRAINT:.*]] = arith.cmpi eq, %[[CONSTRAINT_VAL]], %[[C0]] // CHECK-DAG: %[[X_L:.*]] = arith.cmpi sge, %[[X]], %[[C1]] // CHECK-DAG: %[[X_H:.*]] = arith.cmpi sle, %[[X]], %[[C7]] @@ -517,9 +517,9 @@ TEST_F(ElementalHloToMlirTest, Pad) { // CHECK: %[[FROM_INPUT:.*]] = arith.andi %[[X_AND_CONSTRAINT]], %[[Y_BOUNDS]] // CHECK: %[[RET:.*]] = scf.if %[[FROM_INPUT]] // CHECK: %[[IN0:.*]] = xla_gpu.apply_indexing - // CHECK-SAME: <"(d0) -> ((d0 - 1) floordiv 2), domain: d0 in [1, 7], is_simplified: true">(%[[X]]) + // CHECK-SAME: <"(d0) -> ((d0 - 1) floordiv 2), domain: d0 in [1, 7]">(%[[X]]) // CHECK: %[[IN1:.*]] = xla_gpu.apply_indexing - // CHECK-SAME: <"(d0) -> (d0 - 4), domain: d0 in [4, 7], is_simplified: true">(%[[Y]]) + // CHECK-SAME: <"(d0) -> (d0 - 4), domain: d0 in [4, 7]">(%[[Y]]) // CHECK: %[[VAL:.*]] = tensor.extract %[[ARG0]][%[[IN0]], %[[IN1]]] // CHECK: scf.yield %[[VAL]] // CHECK: } else { @@ -547,7 +547,7 @@ TEST_F(ElementalHloToMlirTest, PadUnsigned) { // CHECK-DAG: %[[C4:.*]] = arith.constant 4 // CHECK-DAG: %[[C7:.*]] = arith.constant 7 // CHECK: %[[CONSTRAINT_VAL:.*]] = xla_gpu.apply_indexing - // CHECK-SAME: <"(d0) -> ((d0 - 1) mod 2), domain: d0 in [1, 7], is_simplified: true">(%[[X]]) + // CHECK-SAME: <"(d0) -> ((d0 - 1) mod 2), domain: d0 in [1, 7]">(%[[X]]) // CHECK: %[[CONSTRAINT:.*]] = arith.cmpi eq, %[[CONSTRAINT_VAL]], %[[C0]] // CHECK-DAG: %[[X_L:.*]] = arith.cmpi sge, %[[X]], %[[C1]] // CHECK-DAG: %[[X_H:.*]] = arith.cmpi sle, %[[X]], %[[C7]] @@ -559,9 +559,9 @@ TEST_F(ElementalHloToMlirTest, PadUnsigned) { // CHECK: %[[FROM_INPUT:.*]] = arith.andi %[[X_AND_CONSTRAINT]], %[[Y_BOUNDS]] // CHECK: %[[RET:.*]] = scf.if %[[FROM_INPUT]] // CHECK: %[[IN0:.*]] = xla_gpu.apply_indexing - // CHECK-SAME: <"(d0) -> ((d0 - 1) floordiv 2), domain: d0 in [1, 7], is_simplified: true">(%[[X]]) + // CHECK-SAME: <"(d0) -> ((d0 - 1) floordiv 2), domain: d0 in [1, 7]">(%[[X]]) // CHECK: %[[IN1:.*]] = xla_gpu.apply_indexing - // CHECK-SAME: <"(d0) -> (d0 - 4), domain: d0 in [4, 7], is_simplified: true">(%[[Y]]) + // CHECK-SAME: <"(d0) -> (d0 - 4), domain: d0 in [4, 7]">(%[[Y]]) // CHECK: %[[VAL:.*]] = tensor.extract %[[ARG0]][%[[IN0]], %[[IN1]]] // CHECK: scf.yield %[[VAL]] // CHECK: } else { @@ -879,10 +879,10 @@ TEST_F(ElementalHloToMlirTest, ConvolutionSimple) { // CHECK: %[[R3:.+]] = scf.if {{.+}} -> (f32) { // CHECK: %[[XX0:.+]] = xla_gpu.apply_indexing // CHECK-SAME: #xla_gpu.indexing_map<"(d0, d1) -> (d0 + d1), - // CHECK-SAME: d0 in [0, 5], d1 in [0, 2], is_simplified: true">(%[[W]], %[[X]]) + // CHECK-SAME: d0 in [0, 5], d1 in [0, 2]">(%[[W]], %[[X]]) // CHECK: %[[XX1:.+]] = xla_gpu.apply_indexing // CHECK-SAME: #xla_gpu.indexing_map<"(d0, d1) -> (d0 + d1), - // CHECK-SAME: d0 in [0, 7], d1 in [0, 4], is_simplified: true">(%[[H]], %[[Y]]) + // CHECK-SAME: d0 in [0, 7], d1 in [0, 4]">(%[[H]], %[[Y]]) // CHECK-DAG: %[[VL:.+]] = tensor.extract %[[LHS]][%[[B]], %[[XX0]], %[[XX1]], %[[I]]] : tensor<2x8x12x4xf32> // CHECK-DAG: %[[VR:.+]] = tensor.extract %[[RHS]][%[[I]], 
%[[X]], %[[Y]], %[[O]]] : tensor<4x3x5x16xf32> // CHECK: %[[MUL:.+]] = arith.mulf %[[VL]], %[[VR]] : f32 @@ -925,10 +925,10 @@ TEST_F(ElementalHloToMlirTest, ConvolutionWithWindowStrides) { // CHECK: %[[R3:.+]] = scf.if {{.+}} -> (f32) { // CHECK: %[[XX0:.+]] = xla_gpu.apply_indexing // CHECK-SAME: #xla_gpu.indexing_map<"(d0, d1) -> (d0 * 2 + d1), - // CHECK-SAME: d0 in [0, 2], d1 in [0, 2], is_simplified: true">(%[[W]], %[[X]]) + // CHECK-SAME: d0 in [0, 2], d1 in [0, 2]">(%[[W]], %[[X]]) // CHECK: %[[XX1:.+]] = xla_gpu.apply_indexing // CHECK-SAME: #xla_gpu.indexing_map<"(d0, d1) -> (d0 * 2 + d1), - // CHECK-SAME: d0 in [0, 3], d1 in [0, 4], is_simplified: true">(%[[H]], %[[Y]]) + // CHECK-SAME: d0 in [0, 3], d1 in [0, 4]">(%[[H]], %[[Y]]) // CHECK-DAG: %[[VL:.+]] = tensor.extract %[[LHS]][%[[B]], %[[XX0]], %[[XX1]], %[[I]]] : tensor<2x8x12x4xf32> // CHECK-DAG: %[[VR:.+]] = tensor.extract %[[RHS]][%[[I]], %[[X]], %[[Y]], %[[O]]] : tensor<4x3x5x16xf32> // CHECK: %[[MUL:.+]] = arith.mulf %[[VL]], %[[VR]] : f32 @@ -971,21 +971,21 @@ TEST_F(ElementalHloToMlirTest, ConvolutionWithPadding) { // CHECK: %[[R0:.+]] = scf.for %[[X:.+]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[A0:.+]] = %[[INIT]]) -> (f32) { // CHECK-NEXT: %[[R1:.+]] = scf.for %[[Y:.+]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[A1:.+]] = %[[A0]]) -> (f32) { // CHECK-NEXT: %[[R2:.+]] = scf.for %[[I:.+]] = %[[C0]] to %[[C4]] step %[[C1]] iter_args(%[[ACC:.+]] = %[[A1]]) -> (f32) { - // CHECK-DAG: %[[TESTX:.+]] = xla_gpu.apply_indexing #xla_gpu.indexing_map<"(d0, d1) -> (d0 + d1), domain: d0 in [0, 7], d1 in [0, 2], is_simplified: true">(%[[W]], %[[X]]) + // CHECK-DAG: %[[TESTX:.+]] = xla_gpu.apply_indexing #xla_gpu.indexing_map<"(d0, d1) -> (d0 + d1), domain: d0 in [0, 7], d1 in [0, 2]">(%[[W]], %[[X]]) // CHECK-DAG: %[[TXGE:.+]] = arith.cmpi sge, %[[TESTX]], %[[C1]] : index // CHECK-DAG: %[[TXLE:.+]] = arith.cmpi sle, %[[TESTX]], %[[C8]] : index // CHECK-DAG: %[[TX:.+]] = arith.andi %[[TXGE]], %[[TXLE]] : i1 - // CHECK-DAG: %[[TESTY:.+]] = xla_gpu.apply_indexing #xla_gpu.indexing_map<"(d0, d1) -> (d0 + d1), domain: d0 in [0, 11], d1 in [0, 4], is_simplified: true">(%[[H]], %[[Y]]) + // CHECK-DAG: %[[TESTY:.+]] = xla_gpu.apply_indexing #xla_gpu.indexing_map<"(d0, d1) -> (d0 + d1), domain: d0 in [0, 11], d1 in [0, 4]">(%[[H]], %[[Y]]) // CHECK-DAG: %[[TYGE:.+]] = arith.cmpi sge, %[[TESTY]], %[[C2]] : index // CHECK-DAG: %[[TYLE:.+]] = arith.cmpi sle, %[[TESTY]], %[[C13]] : index // CHECK-DAG: %[[TY:.+]] = arith.andi %[[TYGE]], %[[TYLE]] : i1 // CHECK: %[[R3:.+]] = scf.if {{.+}} -> (f32) { // CHECK: %[[XX0:.+]] = xla_gpu.apply_indexing // CHECK-SAME: #xla_gpu.indexing_map<"(d0, d1) -> (d0 + d1 - 1), - // CHECK-SAME: d0 in [0, 7], d1 in [0, 2], is_simplified: true">(%[[W]], %[[X]]) + // CHECK-SAME: d0 in [0, 7], d1 in [0, 2]">(%[[W]], %[[X]]) // CHECK: %[[XX1:.+]] = xla_gpu.apply_indexing // CHECK-SAME: #xla_gpu.indexing_map<"(d0, d1) -> (d0 + d1 - 2), - // CHECK-SAME: d0 in [0, 11], d1 in [0, 4], is_simplified: true">(%[[H]], %[[Y]]) + // CHECK-SAME: d0 in [0, 11], d1 in [0, 4]">(%[[H]], %[[Y]]) // CHECK-DAG: %[[VL:.+]] = tensor.extract %[[LHS]][%[[B]], %[[XX0]], %[[XX1]], %[[I]]] : tensor<2x8x12x4xf32> // CHECK-DAG: %[[VR:.+]] = tensor.extract %[[RHS]][%[[I]], %[[X]], %[[Y]], %[[O]]] : tensor<4x3x5x16xf32> // CHECK: %[[MUL:.+]] = arith.mulf %[[VL]], %[[VR]] : f32 @@ -1025,17 +1025,17 @@ TEST_F(ElementalHloToMlirTest, ConvolutionWithLhsDilation) { // CHECK: %[[R0:.+]] = scf.for %[[X:.+]] = %[[C0]] to %[[C3]] step %[[C1]] 
iter_args(%[[A0:.+]] = %[[INIT]]) -> (f32) { // CHECK-NEXT: %[[R1:.+]] = scf.for %[[Y:.+]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[A1:.+]] = %[[A0]]) -> (f32) { // CHECK-NEXT: %[[R2:.+]] = scf.for %[[I:.+]] = %[[C0]] to %[[C4]] step %[[C1]] iter_args(%[[ACC:.+]] = %[[A1]]) -> (f32) { - // CHECK-DAG: %[[TESTX:.+]] = xla_gpu.apply_indexing #xla_gpu.indexing_map<"(d0, d1) -> ((d0 + d1) mod 2), domain: d0 in [0, 12], d1 in [0, 2], is_simplified: true">(%[[W]], %[[X]]) + // CHECK-DAG: %[[TESTX:.+]] = xla_gpu.apply_indexing #xla_gpu.indexing_map<"(d0, d1) -> ((d0 + d1) mod 2), domain: d0 in [0, 12], d1 in [0, 2]">(%[[W]], %[[X]]) // CHECK-DAG: %[[TX:.+]] = arith.cmpi eq, %[[TESTX]], %[[C0]] : index - // CHECK-DAG: %[[TESTY:.+]] = xla_gpu.apply_indexing #xla_gpu.indexing_map<"(d0, d1) -> ((d0 + d1) mod 2), domain: d0 in [0, 18], d1 in [0, 4], is_simplified: true">(%[[H]], %[[Y]]) + // CHECK-DAG: %[[TESTY:.+]] = xla_gpu.apply_indexing #xla_gpu.indexing_map<"(d0, d1) -> ((d0 + d1) mod 2), domain: d0 in [0, 18], d1 in [0, 4]">(%[[H]], %[[Y]]) // CHECK-DAG: %[[TY:.+]] = arith.cmpi eq, %[[TESTY]], %[[C0]] : index // CHECK: %[[R3:.+]] = scf.if {{.+}} -> (f32) { // CHECK: %[[XX0:.+]] = xla_gpu.apply_indexing // CHECK-SAME: #xla_gpu.indexing_map<"(d0, d1) -> ((d0 + d1) floordiv 2), - // CHECK-SAME: d0 in [0, 12], d1 in [0, 2], is_simplified: true">(%[[W]], %[[X]]) + // CHECK-SAME: d0 in [0, 12], d1 in [0, 2]">(%[[W]], %[[X]]) // CHECK: %[[XX1:.+]] = xla_gpu.apply_indexing // CHECK-SAME: #xla_gpu.indexing_map<"(d0, d1) -> ((d0 + d1) floordiv 2), - // CHECK-SAME: d0 in [0, 18], d1 in [0, 4], is_simplified: true">(%[[H]], %[[Y]]) + // CHECK-SAME: d0 in [0, 18], d1 in [0, 4]">(%[[H]], %[[Y]]) // CHECK-DAG: %[[VL:.+]] = tensor.extract %[[LHS]][%[[B]], %[[XX0]], %[[XX1]], %[[I]]] : tensor<2x8x12x4xf32> // CHECK-DAG: %[[VR:.+]] = tensor.extract %[[RHS]][%[[I]], %[[X]], %[[Y]], %[[O]]] : tensor<4x3x5x16xf32> // CHECK: %[[MUL:.+]] = arith.mulf %[[VL]], %[[VR]] : f32 @@ -1078,10 +1078,10 @@ TEST_F(ElementalHloToMlirTest, ConvolutionWithRhsDilation) { // CHECK: %[[R3:.+]] = scf.if {{.+}} -> (f32) { // CHECK: %[[XX0:.+]] = xla_gpu.apply_indexing // CHECK-SAME: #xla_gpu.indexing_map<"(d0, d1) -> (d1 * 2 + d0), - // CHECK-SAME: d0 in [0, 3], d1 in [0, 2], is_simplified: true">(%[[W]], %[[X]]) + // CHECK-SAME: d0 in [0, 3], d1 in [0, 2]">(%[[W]], %[[X]]) // CHECK: %[[XX1:.+]] = xla_gpu.apply_indexing // CHECK-SAME: #xla_gpu.indexing_map<"(d0, d1) -> (d1 * 2 + d0), - // CHECK-SAME: d0 in [0, 3], d1 in [0, 4], is_simplified: true">(%[[H]], %[[Y]]) + // CHECK-SAME: d0 in [0, 3], d1 in [0, 4]">(%[[H]], %[[Y]]) // CHECK-DAG: %[[VL:.+]] = tensor.extract %[[LHS]][%[[B]], %[[XX0]], %[[XX1]], %[[I]]] : tensor<2x8x12x4xf32> // CHECK-DAG: %[[VR:.+]] = tensor.extract %[[RHS]][%[[I]], %[[X]], %[[Y]], %[[O]]] : tensor<4x3x5x16xf32> // CHECK: %[[MUL:.+]] = arith.mulf %[[VL]], %[[VR]] : f32 @@ -1124,13 +1124,13 @@ TEST_F(ElementalHloToMlirTest, ConvolutionWithFeatureGroupCount) { // CHECK: %[[R3:.+]] = scf.if {{.+}} -> (f32) { // CHECK: %[[XX0:.+]] = xla_gpu.apply_indexing // CHECK-SAME: #xla_gpu.indexing_map<"(d0, d1) -> (d0 + d1), - // CHECK-SAME: d0 in [0, 5], d1 in [0, 2], is_simplified: true">(%[[W]], %[[X]]) + // CHECK-SAME: d0 in [0, 5], d1 in [0, 2]">(%[[W]], %[[X]]) // CHECK: %[[XX1:.+]] = xla_gpu.apply_indexing // CHECK-SAME: #xla_gpu.indexing_map<"(d0, d1) -> (d0 + d1), - // CHECK-SAME: d0 in [0, 7], d1 in [0, 4], is_simplified: true">(%[[H]], %[[Y]]) + // CHECK-SAME: d0 in [0, 7], d1 in [0, 4]">(%[[H]], %[[Y]]) // 
CHECK: %[[XX2:.+]] = xla_gpu.apply_indexing // CHECK-SAME: #xla_gpu.indexing_map<"(d0, d1) -> ((d0 floordiv 8) * 2 + d1), - // CHECK-SAME: d0 in [0, 15], d1 in [0, 1], is_simplified: true">(%[[O]], %[[I]]) + // CHECK-SAME: d0 in [0, 15], d1 in [0, 1]">(%[[O]], %[[I]]) // CHECK-DAG: %[[VL:.+]] = tensor.extract %[[LHS]][%[[B]], %[[XX0]], %[[XX1]], %[[XX2]]] : tensor<2x8x12x4xf32> // CHECK-DAG: %[[VR:.+]] = tensor.extract %[[RHS]][%[[I]], %[[X]], %[[Y]], %[[O]]] : tensor<2x3x5x16xf32> // CHECK: %[[MUL:.+]] = arith.mulf %[[VL]], %[[VR]] : f32 @@ -1175,10 +1175,10 @@ TEST_F(ElementalHloToMlirTest, ConvolutionWithBatchGroupCount) { // CHECK: %[[R4:.+]] = scf.if {{.+}} -> (f32) { // CHECK: %[[XX0:.+]] = xla_gpu.apply_indexing // CHECK-SAME: #xla_gpu.indexing_map<"(d0, d1) -> (d0 + d1), - // CHECK-SAME: d0 in [0, 5], d1 in [0, 2], is_simplified: true">(%[[W]], %[[X]]) + // CHECK-SAME: d0 in [0, 5], d1 in [0, 2]">(%[[W]], %[[X]]) // CHECK: %[[XX1:.+]] = xla_gpu.apply_indexing // CHECK-SAME: #xla_gpu.indexing_map<"(d0, d1) -> (d0 + d1), - // CHECK-SAME: d0 in [0, 7], d1 in [0, 4], is_simplified: true">(%[[H]], %[[Y]]) + // CHECK-SAME: d0 in [0, 7], d1 in [0, 4]">(%[[H]], %[[Y]]) // CHECK-DAG: %[[VL:.+]] = tensor.extract %[[LHS]][%[[G]], %[[XX0]], %[[XX1]], %[[I]]] : tensor<2x8x12x4xf32> // CHECK-DAG: %[[VR:.+]] = tensor.extract %[[RHS]][%[[I]], %[[X]], %[[Y]], %[[O]]] : tensor<4x3x5x16xf32> // CHECK: %[[MUL:.+]] = arith.mulf %[[VL]], %[[VR]] : f32 @@ -1645,7 +1645,7 @@ TEST_F(ElementalHloToMlirTest, MixedIndexingTuple) { // CHECK: %[[A:.*]] = tensor.extract %[[P0]][%[[X]], %[[Y]]] // CHECK: %[[IDX:.*]] = xla_gpu.apply_indexing // CHECK-SAME: #xla_gpu.indexing_map<"(d0, d1) -> (d0 * 10 + d1), - // CHECK-SAME: d0 in [0, 9], d1 in [0, 9], is_simplified: true">(%[[X]], %[[Y]]) + // CHECK-SAME: d0 in [0, 9], d1 in [0, 9]">(%[[X]], %[[Y]]) // CHECK: %[[B:.*]] = tensor.extract %[[P1]][%[[IDX]]] // CHECK: return %[[A]], %[[B]] )")); @@ -1669,7 +1669,7 @@ TEST_F(ElementalHloToMlirTest, NestedTuple) { // CHECK: %[[P0_V:.*]] = xla_gpu.pure_call @main_p0 // CHECK: %[[IDX:.*]] = // CHECK-SAME: #xla_gpu.indexing_map<"(d0, d1) -> (d0 * 10 + d1), - // CHECK-SAME: d0 in [0, 9], d1 in [0, 9], is_simplified: true">(%[[X]], %[[Y]]) + // CHECK-SAME: d0 in [0, 9], d1 in [0, 9]">(%[[X]], %[[Y]]) // CHECK: %[[P1_V:.*]] = xla_gpu.pure_call @main_p1 // CHECK-SAME: (%[[P0]], %[[P1]], %[[IDX]]) // CHECK: return %[[P0_V]], %[[P1_V]], %[[P1_V]], %[[P1_V]], %[[P0_V]] diff --git a/xla/service/gpu/fusions/transforms/tests/flatten_tensors.mlir b/xla/service/gpu/fusions/transforms/tests/flatten_tensors.mlir index e88324f698d489..d35dc71ddad023 100644 --- a/xla/service/gpu/fusions/transforms/tests/flatten_tensors.mlir +++ b/xla/service/gpu/fusions/transforms/tests/flatten_tensors.mlir @@ -8,7 +8,7 @@ func.func @tensor_extract( : tensor<2x3xf32, dense<[0, 1]> : tensor<2xi64>> func.return %v : f32 } -// CHECK: #[[$MAP:.+]] = #xla_gpu.indexing_map<"(d0, d1) -> (d1 * 2 + d0), domain: d0 in [0, 1], d1 in [0, 2], is_simplified: true"> +// CHECK: #[[$MAP:.+]] = #xla_gpu.indexing_map<"(d0, d1) -> (d1 * 2 + d0), domain: d0 in [0, 1], d1 in [0, 2]"> // CHECK-LABEL: func.func @tensor_extract( // CHECK-SAME: %[[SRC:.*]]: tensor<6xf32>, @@ -67,7 +67,7 @@ func.func @atomic_rmw(%in: tensor<2x4xf32>, %i: index, %j: index) } return %ret : tensor<2x4xf32> } -// CHECK: #[[$MAP:.+]] = #xla_gpu.indexing_map<"(d0, d1) -> (d0 * 4 + d1), domain: d0 in [0, 1], d1 in [0, 3], is_simplified: true"> +// CHECK: #[[$MAP:.+]] = #xla_gpu.indexing_map<"(d0, d1) -> (d0 
* 4 + d1), domain: d0 in [0, 1], d1 in [0, 3]"> // CHECK-LABEL: func.func @atomic_rmw( // CHECK-SAME: %[[TENSOR:.*]]: tensor<8xf32>, %[[I:.*]]: index, // CHECK-SAME: %[[J:.*]]: index) -> tensor<8xf32> { @@ -114,9 +114,9 @@ func.func @for_loop(%t0: tensor<32x1024xf32>, %t1: tensor<64x8x4xf32>) // ----- -#map = #xla_gpu.indexing_map<"(d0, d1) -> ((d1 * 128 + d0) floordiv 36), domain: d0 in [0, 127], d1 in [0, 393749], is_simplified: true"> -#map1 = #xla_gpu.indexing_map<"(d0, d1) -> (((d1 * 128 + d0) floordiv 9) mod 4), domain: d0 in [0, 127], d1 in [0, 393749], is_simplified: true"> -#map2 = #xla_gpu.indexing_map<"(d0, d1) -> ((d1 * 128 + d0) mod 9), domain: d0 in [0, 127], d1 in [0, 393749], is_simplified: true"> +#map = #xla_gpu.indexing_map<"(d0, d1) -> ((d1 * 128 + d0) floordiv 36), domain: d0 in [0, 127], d1 in [0, 393749]"> +#map1 = #xla_gpu.indexing_map<"(d0, d1) -> (((d1 * 128 + d0) floordiv 9) mod 4), domain: d0 in [0, 127], d1 in [0, 393749]"> +#map2 = #xla_gpu.indexing_map<"(d0, d1) -> ((d1 * 128 + d0) mod 9), domain: d0 in [0, 127], d1 in [0, 393749]"> func.func @if_op(%arg0: tensor<4000x4x9xf32>, %arg1: tensor<1400x1xi32>, %arg2: tensor<1400x1x4x9xf32>, %arg3: tensor<4000x4x9xf32>) -> tensor<4000x4x9xf32> { diff --git a/xla/service/gpu/fusions/transforms/tests/fuse_loops.mlir b/xla/service/gpu/fusions/transforms/tests/fuse_loops.mlir index 594c8e1deec7d2..1287b8fc3e91a5 100644 --- a/xla/service/gpu/fusions/transforms/tests/fuse_loops.mlir +++ b/xla/service/gpu/fusions/transforms/tests/fuse_loops.mlir @@ -8,8 +8,7 @@ " domain:" " d0 in [0, 127], d1 in [0, 599]," " s0 in [0, 7], s1 in [0, 0]," -" (d1 mod 6) * 32 + d0 mod 32 in [0, 169]," -" is_simplified: true"> +" (d1 mod 6) * 32 + d0 mod 32 in [0, 169]"> #indexing_map1 = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] ->" " (0," " d0 mod 32," @@ -17,8 +16,7 @@ " domain:" " d0 in [0, 127], d1 in [0, 599]," " s0 in [0, 7], s1 in [0, 0]," -" (d1 mod 6) * 32 + d0 mod 32 in [0, 169]," -" is_simplified: true"> +" (d1 mod 6) * 32 + d0 mod 32 in [0, 169]"> func.func @fuse_loops(%arg0: tensor<20x160x170xf32>) -> tensor<1x32x33xf32> { %cst = arith.constant dense<0.000000e+00> : vector<8x1xf32> %c0 = arith.constant 0 : index @@ -67,8 +65,7 @@ func.func @fuse_loops(%arg0: tensor<20x160x170xf32>) -> tensor<1x32x33xf32> { " domain:" " d0 in [0, 127], d1 in [0, 599]," " s0 in [0, 7], s1 in [0, 0]," -" (d1 mod 6) * 32 + d0 mod 32 in [0, 169]," -" is_simplified: true"> +" (d1 mod 6) * 32 + d0 mod 32 in [0, 169]"> #indexing_map1 = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] ->" " (0," " d0 mod 32," @@ -76,8 +73,7 @@ func.func @fuse_loops(%arg0: tensor<20x160x170xf32>) -> tensor<1x32x33xf32> { " domain:" " d0 in [0, 127], d1 in [0, 599]," " s0 in [0, 7], s1 in [0, 0]," -" (d1 mod 6) * 32 + d0 mod 32 in [0, 169]," -" is_simplified: true"> +" (d1 mod 6) * 32 + d0 mod 32 in [0, 169]"> func.func @do_not_fuse_index_mismatch(%arg0: tensor<20x160x170xf32>) -> tensor<1x32x33xf32> { %cst = arith.constant dense<0.000000e+00> : vector<8x1xf32> %c0 = arith.constant 0 : index @@ -115,8 +111,7 @@ func.func @do_not_fuse_index_mismatch(%arg0: tensor<20x160x170xf32>) -> tensor<1 " domain:" " d0 in [0, 127], d1 in [0, 599]," " s0 in [0, 7], s1 in [0, 0]," -" (d1 mod 6) * 32 + d0 mod 32 in [0, 169]," -" is_simplified: true"> +" (d1 mod 6) * 32 + d0 mod 32 in [0, 169]"> #indexing_map1 = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] ->" " (0," " d0 mod 32," @@ -124,8 +119,7 @@ func.func @do_not_fuse_index_mismatch(%arg0: tensor<20x160x170xf32>) -> tensor<1 " domain:" " d0 in [0, 127], 
d1 in [0, 599]," " s0 in [0, 7], s1 in [0, 0]," -" (d1 mod 6) * 32 + d0 mod 32 in [0, 169]," -" is_simplified: true"> +" (d1 mod 6) * 32 + d0 mod 32 in [0, 169]"> func.func @do_not_fuse_multiple_uses(%arg0: tensor<20x160x170xf32>) -> tensor<1x32x33xf32> { %cst = arith.constant dense<0.000000e+00> : vector<8x1xf32> %c0 = arith.constant 0 : index @@ -165,8 +159,7 @@ func.func @do_not_fuse_multiple_uses(%arg0: tensor<20x160x170xf32>) -> tensor<1x " domain:" " d0 in [0, 127], d1 in [0, 599]," " s0 in [0, 7], s1 in [0, 0]," -" (d1 mod 6) * 32 + d0 mod 32 in [0, 169]," -" is_simplified: true"> +" (d1 mod 6) * 32 + d0 mod 32 in [0, 169]"> #indexing_map1 = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] ->" " (0," " d0 mod 32," @@ -174,8 +167,7 @@ func.func @do_not_fuse_multiple_uses(%arg0: tensor<20x160x170xf32>) -> tensor<1x " domain:" " d0 in [0, 127], d1 in [0, 599]," " s0 in [0, 5], s1 in [0, 0]," -" (d1 mod 6) * 32 + d0 mod 32 in [0, 169]," -" is_simplified: true"> +" (d1 mod 6) * 32 + d0 mod 32 in [0, 169]"> func.func @do_not_fuse_map_domain_mismatch(%arg0: tensor<20x160x170xf32>) -> tensor<1x32x33xf32> { %cst = arith.constant dense<0.000000e+00> : vector<8x1xf32> %c0 = arith.constant 0 : index @@ -214,8 +206,7 @@ func.func @do_not_fuse_map_domain_mismatch(%arg0: tensor<20x160x170xf32>) -> ten " domain:" " d0 in [0, 127], d1 in [0, 599]," " s0 in [0, 7], s1 in [0, 0]," -" (d1 mod 6) * 32 + d0 mod 32 in [0, 169]," -" is_simplified: true"> +" (d1 mod 6) * 32 + d0 mod 32 in [0, 169]"> #indexing_map1 = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] ->" " (0," " d0 mod 32," @@ -223,8 +214,7 @@ func.func @do_not_fuse_map_domain_mismatch(%arg0: tensor<20x160x170xf32>) -> ten " domain:" " d0 in [0, 127], d1 in [0, 599]," " s0 in [0, 7], s1 in [0, 0]," -" (d1 mod 5) * 32 + d0 mod 32 in [0, 169]," -" is_simplified: true"> +" (d1 mod 5) * 32 + d0 mod 32 in [0, 169]"> func.func @do_not_fuse_map_constraint_mismatch(%arg0: tensor<20x160x170xf32>) -> tensor<1x32x33xf32> { %cst = arith.constant dense<0.000000e+00> : vector<8x1xf32> %c0 = arith.constant 0 : index @@ -263,8 +253,7 @@ func.func @do_not_fuse_map_constraint_mismatch(%arg0: tensor<20x160x170xf32>) -> " domain:" " d0 in [0, 127], d1 in [0, 599]," " s0 in [0, 7], s1 in [0, 0], s2 in [0, 1]," -" (d1 mod 6) * 32 + d0 mod 32 in [0, 169]," -" is_simplified: true"> +" (d1 mod 6) * 32 + d0 mod 32 in [0, 169]"> #indexing_map1 = #xla_gpu.indexing_map<"(d0, d1)[s0, s1, s2] ->" " (0," " d0 mod 32," @@ -272,8 +261,7 @@ func.func @do_not_fuse_map_constraint_mismatch(%arg0: tensor<20x160x170xf32>) -> " domain:" " d0 in [0, 127], d1 in [0, 599]," " s0 in [0, 7], s1 in [0, 0], s2 in [0, 1]," -" (d1 mod 6) * 32 + d0 mod 32 in [0, 169]," -" is_simplified: true"> +" (d1 mod 6) * 32 + d0 mod 32 in [0, 169]"> func.func @do_not_fuse_unused_loop_iv(%arg0: tensor<20x160x170xf32>) -> tensor<1x32x33xf32> { %cst = arith.constant dense<0.000000e+00> : vector<8x1xf32> %c0 = arith.constant 0 : index diff --git a/xla/service/gpu/fusions/transforms/tests/lower_xla_gpu_loops_to_scf.mlir b/xla/service/gpu/fusions/transforms/tests/lower_xla_gpu_loops_to_scf.mlir index 427e764d12b914..f981cef83029d8 100644 --- a/xla/service/gpu/fusions/transforms/tests/lower_xla_gpu_loops_to_scf.mlir +++ b/xla/service/gpu/fusions/transforms/tests/lower_xla_gpu_loops_to_scf.mlir @@ -2,8 +2,7 @@ // RUN: --split-input-file | FileCheck %s #map = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (s0 + 1, s1 - 1)," - "domain: d0 in [0, 3], s0 in [0, 1024], s1 in [0, 32], s0 + s1 in [0, 90]," - "is_simplified: false"> + "domain: 
d0 in [0, 3], s0 in [0, 1024], s1 in [0, 32], s0 + s1 in [0, 90]"> func.func @loop_op(%input: tensor<1024x32xf32>, %init: f32, %dim: index) -> (f32) { %sum = xla_gpu.loop (%dim)[%i, %j] -> (%ra, %rb) @@ -61,8 +60,7 @@ func.func @loop_op(%input: tensor<1024x32xf32>, %init: f32, %dim: index) -> (f32 // ----- #map = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (s0 + 1, s1 - 1)," - "domain: d0 in [0, 3], s0 in [0, 1024], s1 in [0, 32], s0 + s1 in [0, 90]," - "is_simplified: false"> + "domain: d0 in [0, 3], s0 in [0, 1024], s1 in [0, 32], s0 + s1 in [0, 90]"> func.func @loop_yields_value_from_above(%input: tensor<1024x32xf32>, %init: f32, %dim: index) -> (f32) { diff --git a/xla/service/gpu/fusions/transforms/tests/lower_xla_gpu_to_scf.mlir b/xla/service/gpu/fusions/transforms/tests/lower_xla_gpu_to_scf.mlir index 347ed9a943ef82..f53ccc1e8ae54f 100644 --- a/xla/service/gpu/fusions/transforms/tests/lower_xla_gpu_to_scf.mlir +++ b/xla/service/gpu/fusions/transforms/tests/lower_xla_gpu_to_scf.mlir @@ -124,8 +124,8 @@ func.func @predicated_extract( func.func private @exp(%p0: tensor<32x64xf32>, %i: index, %j: index) -> f32 -#map = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] -> (d1*32+d0*2+s0, s1), domain: d0 in [0, 32], d1 in [0, 8], s0 in [0, 1], s1 in [0, 1], is_simplified: false"> -#map1 = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] -> (d0*2+s0, s1), domain: d0 in [0, 32], d1 in [0, 2], s0 in [0, 1], s1 in [0, 1], is_simplified: false"> +#map = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] -> (d1*32+d0*2+s0, s1), domain: d0 in [0, 32], d1 in [0, 8], s0 in [0, 1], s1 in [0, 1]"> +#map1 = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] -> (d0*2+s0, s1), domain: d0 in [0, 32], d1 in [0, 2], s0 in [0, 1], s1 in [0, 1]"> func.func @materialize(%input: tensor<32x64xf32>, %i: index, %j: index) -> !xla_gpu.indexed_vector<32x2x2xf32, #map1> { @@ -149,8 +149,8 @@ func.func @materialize(%input: tensor<32x64xf32>, %i: index, %j: index) // ----- -#map = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] -> (d1*32+d0*2+s0, s1), domain: d0 in [0, 32], d1 in [0, 8], s0 in [0, 1], s1 in [0, 1], is_simplified: false"> -#map1 = #xla_gpu.indexing_map<"(d0, d1) -> (d0 mod 16, d1), domain: d0 in [0, 32], d1 in [0, 2], is_simplified: false"> +#map = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] -> (d1*32+d0*2+s0, s1), domain: d0 in [0, 32], d1 in [0, 8], s0 in [0, 1], s1 in [0, 1]"> +#map1 = #xla_gpu.indexing_map<"(d0, d1) -> (d0 mod 16, d1), domain: d0 in [0, 32], d1 in [0, 2]"> func.func @insert(%input: !xla_gpu.indexed_vector<32x64xf32, #map>, %i: index, %j: index, %output: tensor<32x64xf32>) -> tensor<32x64xf32> { @@ -181,9 +181,9 @@ func.func @insert(%input: !xla_gpu.indexed_vector<32x64xf32, #map>, func.func private @exp(%p0: tensor<32x64xf32>, %i: index, %j: index) -> f32 -#map = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] -> (d1*32+d0*2+s0, s1), domain: d0 in [0, 32], d1 in [0, 8], s0 in [0, 1], s1 in [0, 1], is_simplified: false"> -#map1 = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] -> (d0*2+s0, s1), domain: d0 in [0, 32], d1 in [0, 2], s0 in [0, 1], s1 in [0, 1], is_simplified: false"> -#map2 = #xla_gpu.indexing_map<"(d0, d1) -> (d0, d1), domain: d0 in [0, 32], d1 in [0, 2], is_simplified: false"> +#map = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] -> (d1*32+d0*2+s0, s1), domain: d0 in [0, 32], d1 in [0, 8], s0 in [0, 1], s1 in [0, 1]"> +#map1 = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] -> (d0*2+s0, s1), domain: d0 in [0, 32], d1 in [0, 2], s0 in [0, 1], s1 in [0, 1]"> +#map2 = #xla_gpu.indexing_map<"(d0, d1) -> (d0, d1), domain: d0 in [0, 32], d1 in [0, 2]"> 
func.func @materialize_and_insert(%input: tensor<32x64xf32>, %i: index, %j: index, %output: tensor<32x64xf32>) -> tensor<32x64xf32> { @@ -199,8 +199,8 @@ func.func @materialize_and_insert(%input: tensor<32x64xf32>, %i: index, func.func private @exp(%p0: tensor<32x64xcomplex>, %i: index, %j: index) -> complex -#map = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] -> (d1*32+d0*2+s0, s1), domain: d0 in [0, 32], d1 in [0, 8], s0 in [0, 2], s1 in [0, 3], is_simplified: false"> -#map1 = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] -> (d0*2+s0, s1), domain: d0 in [0, 32], d1 in [0, 2], s0 in [0, 2], s1 in [0, 3], is_simplified: false"> +#map = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] -> (d1*32+d0*2+s0, s1), domain: d0 in [0, 32], d1 in [0, 8], s0 in [0, 2], s1 in [0, 3]"> +#map1 = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] -> (d0*2+s0, s1), domain: d0 in [0, 32], d1 in [0, 2], s0 in [0, 2], s1 in [0, 3]"> func.func @materialize_complex( %input: tensor<32x64xcomplex>, %output: tensor<32x64xcomplex>, @@ -227,8 +227,8 @@ func.func @materialize_complex( // ----- -#map1 = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] -> (d0*2+s0, s1), domain: d0 in [0, 32], d1 in [0, 2], s0 in [0, 2], s1 in [0, 3], is_simplified: false"> -#map2 = #xla_gpu.indexing_map<"(d0, d1) -> (d0, d1), domain: d0 in [0, 32], d1 in [0, 2], is_simplified: false"> +#map1 = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] -> (d0*2+s0, s1), domain: d0 in [0, 32], d1 in [0, 2], s0 in [0, 2], s1 in [0, 3]"> +#map2 = #xla_gpu.indexing_map<"(d0, d1) -> (d0, d1), domain: d0 in [0, 32], d1 in [0, 2]"> func.func @insert_complex( %input: !xla_gpu.indexed_vector<32x3x4xcomplex, #map1>, %output: tensor<32x64xcomplex>, diff --git a/xla/service/gpu/fusions/transforms/tests/optimize_loops.mlir b/xla/service/gpu/fusions/transforms/tests/optimize_loops.mlir index 17f478b2838dde..1094b51a2a6841 100644 --- a/xla/service/gpu/fusions/transforms/tests/optimize_loops.mlir +++ b/xla/service/gpu/fusions/transforms/tests/optimize_loops.mlir @@ -1,7 +1,8 @@ // RUN: mlir_fusions_opt %s -split-input-file -xla-gpu-optimize-loops | FileCheck %s -#map = #xla_gpu.indexing_map<"(d0) -> (d0 floordiv 8), domain: d0 in [0, 31], is_simplified: false"> #map1 = #xla_gpu.indexing_map<"(d0) -> (d0 mod 8), domain: d0 in [0, 31], is_simplified: false"> -#map2 = #xla_gpu.indexing_map<"(d0, d1)[s0] -> (d1 * 2 + d0 + s0 * 512), domain: d0 in [0, 1], d1 in [0, 255], s0 in [0, 7], is_simplified: false"> +#map = #xla_gpu.indexing_map<"(d0) -> (d0 floordiv 8), domain: d0 in [0, 31]"> +#map1 = #xla_gpu.indexing_map<"(d0) -> (d0 mod 8), domain: d0 in [0, 31]"> +#map2 = #xla_gpu.indexing_map<"(d0, d1)[s0] -> (d1 * 2 + d0 + s0 * 512), domain: d0 in [0, 1], d1 in [0, 255], s0 in [0, 7]"> module { func.func @fully_unroll(%arg0: tensor<4x8x4096xf32>, %arg1: tensor<4096xbf16>, %arg2: tensor<4x8xf32>, %arg3: tensor<4096xbf16>, @@ -150,7 +151,7 @@ module { %cst = arith.constant dense<[0.0, 0.0]> : vector<2xf32> %cst0 = arith.constant 0.0 : f32 %ret = scf.for %i = %c0 to %c17 step %c1 iter_args (%iter = %cst) -> (vector<2xf32>) { - %base = xla_gpu.apply_indexing #xla_gpu.indexing_map<"(d0) -> (d0 * 2), domain: d0 in [0, 15], is_simplified: false">(%i) + %base = xla_gpu.apply_indexing #xla_gpu.indexing_map<"(d0) -> (d0 * 2), domain: d0 in [0, 15]">(%i) %val = vector.transfer_read %arg[%base], %cst0 : tensor<34xf32>, vector<2xf32> %log = math.log %val : vector<2xf32> %add = arith.addf %log, %iter : vector<2xf32> diff --git a/xla/service/gpu/fusions/transforms/tests/peel_loops.mlir 
b/xla/service/gpu/fusions/transforms/tests/peel_loops.mlir index f965b069a772cc..9ffd7bdc0fbfd1 100644 --- a/xla/service/gpu/fusions/transforms/tests/peel_loops.mlir +++ b/xla/service/gpu/fusions/transforms/tests/peel_loops.mlir @@ -3,7 +3,7 @@ #map = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (s0, s1), domain:" "d0 in [0, 3], s0 in [0, 7], s1 in [0, 10], d0 + s0 in [0, 9]," - "d0 + s1 in [0, 12], is_simplified: false"> + "d0 + s1 in [0, 12]"> func.func @peel_both_loops(%input: tensor<16x32xf32>, %init: f32, %dim: index) -> (f32) { %sum = xla_gpu.loop (%dim)[%i, %j] -> (%r0, %r1) @@ -14,9 +14,9 @@ func.func @peel_both_loops(%input: tensor<16x32xf32>, } func.return %sum : f32 } -// CHECK: #[[$PEELED_MAP:.*]] = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (s0, s1), domain: d0 in [0, 3], s0 in [0, 6], s1 in [0, 9], is_simplified: true"> -// CHECK: #[[$TAIL_MAP0:.*]] = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (7, s1), domain: d0 in [0, 2], s0 in [7, 7], s1 in [0, 9], is_simplified: true"> -// CHECK: #[[$TAIL_MAP1:.*]] = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (s0, 10), domain: d0 in [0, 2], s0 in [0, 7], s1 in [10, 10], is_simplified: true"> +// CHECK: #[[$PEELED_MAP:.*]] = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (s0, s1), domain: d0 in [0, 3], s0 in [0, 6], s1 in [0, 9]"> +// CHECK: #[[$TAIL_MAP0:.*]] = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (7, s1), domain: d0 in [0, 2], s0 in [7, 7], s1 in [0, 9]"> +// CHECK: #[[$TAIL_MAP1:.*]] = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (s0, 10), domain: d0 in [0, 2], s0 in [0, 7], s1 in [10, 10]"> // CHECK-LABEL: func.func @peel_both_loops( // CHECK-SAME: %[[INPUT:.*]]: tensor<16x32xf32>, @@ -42,7 +42,7 @@ func.func @peel_both_loops(%input: tensor<16x32xf32>, // ----- #map = #xla_gpu.indexing_map<"(d0)[s0] -> (s0)," - "domain: d0 in [0, 3], s0 in [0, 7], is_simplified: false"> + "domain: d0 in [0, 3], s0 in [0, 7]"> func.func @not_constrained_symbol(%input: tensor<16xf32>, %init: f32, %dim: index) -> (f32) { %sum = xla_gpu.loop (%dim)[%i] -> (%r0) @@ -64,9 +64,7 @@ func.func @not_constrained_symbol(%input: tensor<16xf32>, %init: f32, " domain:" " d0 in [0, 3]," " s0 in [0, 7]," -" s0 mod 5 in [0, 1]," -" is_simplified: false" -> +" s0 mod 5 in [0, 1]"> func.func @constraint_exists_after_peeling(%input: tensor<16xf32>, %init: f32, %dim: index) -> (f32) { %sum = xla_gpu.loop (%dim)[%i] -> (%r0) diff --git a/xla/service/gpu/fusions/transforms/tests/simplify_affine.mlir b/xla/service/gpu/fusions/transforms/tests/simplify_affine.mlir index bfddbd60e2bde7..e62a530de0e7db 100644 --- a/xla/service/gpu/fusions/transforms/tests/simplify_affine.mlir +++ b/xla/service/gpu/fusions/transforms/tests/simplify_affine.mlir @@ -65,7 +65,7 @@ func.func @op_and_for_ranges(%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: !llvm.pt %2 = xla_gpu.apply_indexing #xla_gpu.indexing_map< "()[s0, s1, s2] -> (s0 * 512 + s1 * 4 + s2 + (s1 floordiv 128) + (s2 floordiv 4))," - "domain: s0 in [0, 3071], s1 in [0, 127], s2 in [0, 3], is_simplified: false">[%1, %0, %i] + "domain: s0 in [0, 3071], s1 in [0, 127], s2 in [0, 3]">[%1, %0, %i] %3 = arith.index_castui %2 : index to i64 %4 = llvm.getelementptr %arg0[%3] : (!llvm.ptr, i64) -> !llvm.ptr, f32 %5 = llvm.load %4 invariant : !llvm.ptr -> f32 @@ -95,7 +95,7 @@ func.func @arg_ranges(%arg0: index, %arg1: index) -> index { %0 = xla_gpu.apply_indexing #xla_gpu.indexing_map< "()[s0, s1] -> (s0 floordiv 100 + s1 floordiv 100)," - "domain: s0 in [0, 42], s1 in [0, 1000], is_simplified: false">[%arg0, %arg1] + "domain: s0 in [0, 42], s1 in [0, 1000]">[%arg0, %arg1] 
return %0 : index } @@ -109,7 +109,7 @@ func.func @arg_ranges(%arg0: index, %arg1: index) -> index { func.func @cant_lower(%arg0: index, %arg1: index) -> (index, index) { %0:2 = xla_gpu.apply_indexing #xla_gpu.indexing_map<"()[s0, s1] -> (s0 floordiv 100 + s1 floordiv 100, s0 + s1)," - "domain: s0 in [-10, 42], s1 in [0, 1000], is_simplified: false">[%arg0, %arg1] + "domain: s0 in [-10, 42], s1 in [0, 1000]">[%arg0, %arg1] return %0#0, %0#1 : index, index } @@ -128,7 +128,7 @@ func.func @order_summands(%arg1: index) { %0 = xla_gpu.apply_indexing #xla_gpu.indexing_map< "()[s0, s1, s2] -> ((s0 + s1) floordiv 3 + s0 * 512 + s1 * 4 + s2 * 10)," - "domain: s0 in [0, 3], s1 in [0, 3], s2 in [0, 3], is_simplified: false">[%arg2, %arg1, %arg3] + "domain: s0 in [0, 3], s1 in [0, 3], s2 in [0, 3]">[%arg2, %arg1, %arg3] "dummy.op"(%0) : (index) -> () } } diff --git a/xla/service/gpu/fusions/transforms/tests/simplify_arith.mlir b/xla/service/gpu/fusions/transforms/tests/simplify_arith.mlir index b301a3bbc93a74..e6fea946e6e827 100644 --- a/xla/service/gpu/fusions/transforms/tests/simplify_arith.mlir +++ b/xla/service/gpu/fusions/transforms/tests/simplify_arith.mlir @@ -249,7 +249,7 @@ func.func @refine_constraints(%tensor: tensor<100xf32>) -> tensor<100xf32> { %loop = scf.for %i = %c0 to %c3 step %c1 iter_args(%in_ = %tensor) -> (tensor<100xf32>) { %0 = xla_gpu.apply_indexing #xla_gpu.indexing_map<"(d0) -> (d0 mod 4)," - "domain: d0 in [0, 9], is_simplified: false">(%i) + "domain: d0 in [0, 9]">(%i) %updated = tensor.insert %c42_f32 into %in_[%0] : tensor<100xf32> scf.yield %updated :tensor<100xf32> } @@ -265,9 +265,9 @@ func.func @refine_constraints(%tensor: tensor<100xf32>) -> tensor<100xf32> { #map = #xla_gpu.indexing_map< "(d0, d1)[s0, s1] -> (((d0 * 4 + d1 * 512 + s1) floordiv 9 + s0 * 32768) mod 2400000)," - "domain: d0 in [0, 127], d1 in [0, 575], s0 in [0, 73], s1 in [0, 3], is_simplified: false"> + "domain: d0 in [0, 127], d1 in [0, 575], s0 in [0, 73], s1 in [0, 3]"> #map1 = #xla_gpu.indexing_map<"(d0, d1)[s0] -> ((d0 * 4 + d1 * 512 + s0) mod 9)," - "domain: d0 in [0, 127], d1 in [0, 575], s0 in [0, 3], is_simplified: false"> + "domain: d0 in [0, 127], d1 in [0, 575], s0 in [0, 3]"> func.func @refine_constraints_for_symbol(%arg0: tensor<2400000x9xf32>, %arg1: tensor<2400000x9xf32>) -> tensor<2400000x9xf32> { %c0 = arith.constant 0 : index @@ -306,8 +306,7 @@ func.func @refine_constraints_for_symbol(%arg0: tensor<2400000x9xf32>, "d4 in [0, 0]," "d5 in [0, 0]," "s0 in [0, 3]," - "d0 * 4 + s0 in [0, 29]," - "is_simplified: false"> + "d0 * 4 + s0 in [0, 29]"> func.func @dus(%arg0: tensor<20x30xf32>, %arg1: tensor<5x6xf32>, %arg2: i32, %arg3: i32, %arg4: tensor<20x30xf32>) -> tensor<20x30xf32> { %c24 = arith.constant 24 : index %c15 = arith.constant 15 : index diff --git a/xla/service/gpu/fusions/transforms/tests/vectorize_loads_stores.mlir b/xla/service/gpu/fusions/transforms/tests/vectorize_loads_stores.mlir index 0c734ca19882e5..ceaa3a0748cbff 100644 --- a/xla/service/gpu/fusions/transforms/tests/vectorize_loads_stores.mlir +++ b/xla/service/gpu/fusions/transforms/tests/vectorize_loads_stores.mlir @@ -2,7 +2,7 @@ // RUN: -xla-gpu-vectorize-loads-stores -cse -canonicalize | FileCheck %s #map = #xla_gpu.indexing_map<"(d0)[s0] -> (d0 * 2 + s0)," - "domain: d0 in [0, 63], s0 in [0, 1], is_simplified: true"> + "domain: d0 in [0, 63], s0 in [0, 1]"> func.func @simple_read(%arg0: tensor<128xf32>) -> (f32) { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index @@ -20,7 +20,7 @@ func.func 
@simple_read(%arg0: tensor<128xf32>) -> (f32) { } return %outer : f32 } -// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<"(d0) -> (d0 * 2), domain: d0 in [0, 63], is_simplified: true"> +// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<"(d0) -> (d0 * 2), domain: d0 in [0, 63]"> // CHECK-LABEL: @simple_read // CHECK-SAME: (%[[ARG0:.*]]: tensor // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index @@ -37,7 +37,7 @@ func.func @simple_read(%arg0: tensor<128xf32>) -> (f32) { // ----- #map = #xla_gpu.indexing_map<"(d0)[s0] -> (d0 * 2 + s0 + 1)," - "domain: d0 in [0, 63], s0 in [0, 1], is_simplified: true"> + "domain: d0 in [0, 63], s0 in [0, 1]"> func.func @misaligned_indexing_map(%arg0: tensor<128xf32>) -> (f32) { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index @@ -61,7 +61,7 @@ func.func @misaligned_indexing_map(%arg0: tensor<128xf32>) -> (f32) { // ----- #map = #xla_gpu.indexing_map<"(d0)[s0] -> (d0 * 3 + s0)," - "domain: d0 in [0, 63], s0 in [0, 1], is_simplified: true"> + "domain: d0 in [0, 63], s0 in [0, 1]"> func.func @misaligned_indexing_map_2(%arg0: tensor<128xf32>) -> (f32) { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index @@ -85,7 +85,7 @@ func.func @misaligned_indexing_map_2(%arg0: tensor<128xf32>) -> (f32) { // ----- #map = #xla_gpu.indexing_map<"(d0)[s0] -> (3 * d0 + s0)," - "domain: d0 in [0, 63], s0 in [0, 1], is_simplified: true"> + "domain: d0 in [0, 63], s0 in [0, 1]"> func.func @misaligned_shape(%arg0: tensor<192xf32>) -> (f32) { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index @@ -109,7 +109,7 @@ func.func @misaligned_shape(%arg0: tensor<192xf32>) -> (f32) { // ----- #map = #xla_gpu.indexing_map<"(d0)[s0] -> (d0 + s0 * 2)," - "domain: d0 in [0, 63], s0 in [0, 1], is_simplified: true"> + "domain: d0 in [0, 63], s0 in [0, 1]"> func.func @wrong_stride(%arg0: tensor<128xf32>) -> (f32) { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index @@ -135,7 +135,7 @@ func.func @wrong_stride(%arg0: tensor<128xf32>) -> (f32) { // We could vectorize this as a float vector load of double the size, but we // don't currently. 
#map = #xla_gpu.indexing_map<"(d0)[s0] -> (2 * d0 + s0)," - "domain: d0 in [0, 127], s0 in [0, 1], is_simplified: true"> + "domain: d0 in [0, 127], s0 in [0, 1]"> func.func @simple_read_complex(%arg0: tensor<128xcomplex>, %i: index) -> (complex) { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index @@ -251,11 +251,10 @@ func.func @write_not_yielded(%arg0: tensor<64xf32>) -> tensor<64xf32> { // ----- #map = #xla_gpu.indexing_map<"(d0, d1)[s0] -> (d1 * 2 + d0 + s0 * 512)," - "domain: d0 in [0, 7], d1 in [0, 255], s0 in [0, 7], is_simplified: true"> + "domain: d0 in [0, 7], d1 in [0, 255], s0 in [0, 7]"> #map1 = #xla_gpu.indexing_map< "(d0, d1, d2)[s0] -> (d0 * 32 + d2 * 2 + d1 + s0 * 512)," - "domain: d0 in [0, 7], d1 in [0, 1], d2 in [0, 255], s0 in [0, 7]," - "is_simplified: true"> + "domain: d0 in [0, 7], d1 in [0, 1], d2 in [0, 255], s0 in [0, 7]"> func.func @multiple(%arg0: tensor<131072xf32>, %arg1: tensor<4096xbf16>, %arg2: tensor<32xf32>, %arg3: tensor<131072xf32>, %arg4: index) -> (tensor<131072xf32>, f32) { @@ -282,8 +281,8 @@ func.func @multiple(%arg0: tensor<131072xf32>, %arg1: tensor<4096xbf16>, } return %0#0, %0#1 : tensor<131072xf32>, f32 } -// CHECK-DAG: #[[$MAP:.*]] = #xla_gpu.indexing_map<"(d0, d1) -> (d0 * 2 + d1 * 512), domain: d0 in [0, 255], d1 in [0, 7], is_simplified: true"> -// CHECK-DAG: #[[$MAP1:.*]] = #xla_gpu.indexing_map<"(d0, d1, d2) -> (d0 * 32 + d1 * 2 + d2 * 512), domain: d0 in [0, 7], d1 in [0, 255], d2 in [0, 7], is_simplified: true"> +// CHECK-DAG: #[[$MAP:.*]] = #xla_gpu.indexing_map<"(d0, d1) -> (d0 * 2 + d1 * 512), domain: d0 in [0, 255], d1 in [0, 7]"> +// CHECK-DAG: #[[$MAP1:.*]] = #xla_gpu.indexing_map<"(d0, d1, d2) -> (d0 * 32 + d1 * 2 + d2 * 512), domain: d0 in [0, 7], d1 in [0, 255], d2 in [0, 7]"> // CHECK-LABEL: @multiple // CHECK-SAME: (%[[ARG0:.*]]: tensor{{.*}}, %[[ARG1:.*]]: tensor{{.*}}, %[[ARG2:.*]]: tensor{{.*}}, %[[ARG3:.*]]: tensor{{.*}}, %[[ARG4:.*]]: index) // CHECK: %[[C0:.*]] = arith.constant 0 : index @@ -307,7 +306,7 @@ func.func @multiple(%arg0: tensor<131072xf32>, %arg1: tensor<4096xbf16>, // ----- #map = #xla_gpu.indexing_map<"(d0)[s0] -> ((d0 * 4) mod 64 + s0)," - "domain: d0 in [0, 63], s0 in [0, 1], is_simplified: true"> + "domain: d0 in [0, 63], s0 in [0, 1]"> func.func @remainder_with_modulo(%arg0: tensor<128xf32>) -> (f32) { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index @@ -335,7 +334,7 @@ func.func @remainder_with_modulo(%arg0: tensor<128xf32>) -> (f32) { // ----- #map = #xla_gpu.indexing_map<"(d0)[s0] -> ((d0 * 4) mod 65 + s0)," - "domain: d0 in [0, 63], s0 in [0, 1], is_simplified: true"> + "domain: d0 in [0, 63], s0 in [0, 1]"> func.func @remainder_with_modulo_misaligned(%arg0: tensor<128xf32>) -> (f32) { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index @@ -359,9 +358,9 @@ func.func @remainder_with_modulo_misaligned(%arg0: tensor<128xf32>) -> (f32) { // ----- #map0 = #xla_gpu.indexing_map<"(d0) -> (d0 + 5)," - "domain: d0 in [0, 63], is_simplified: true"> + "domain: d0 in [0, 63]"> #map1 = #xla_gpu.indexing_map<"(d0)[s0] -> (d0 * 2 + s0)," - "domain: d0 in [0, 63], s0 in [0, 1], is_simplified: true"> + "domain: d0 in [0, 63], s0 in [0, 1]"> module { func.func @apply_indexing_sequence(%arg0: tensor<128xf32>) -> (f32) { %c0 = arith.constant 0 : index @@ -384,7 +383,7 @@ module { } // CHECK: #[[$MAP0:.*]] = #xla_gpu.indexing_map<"(d0) -> (d0 * 2 + 10), -// CHECK-SAME: domain: d0 in [0, 63], is_simplified: true"> +// CHECK-SAME: domain: d0 in [0, 63]"> // CHECK-LABEL: 
@apply_indexing_sequence // CHECK: %[[BASE:.*]] = xla_gpu.apply_indexing #[[$MAP0]] // CHECK: vector.transfer_read {{.*}}[%[[BASE]]] @@ -393,9 +392,9 @@ module { #map0 = #xla_gpu.indexing_map<"(d0) -> (d0 + 5)," - "domain: d0 in [0, 63], is_simplified: true"> + "domain: d0 in [0, 63]"> #map1 = #xla_gpu.indexing_map<"(d0)[s0] -> (d0 * 2 + s0)," - "domain: d0 in [0, 63], s0 in [0, 1], is_simplified: true"> + "domain: d0 in [0, 63], s0 in [0, 1]"> module { func.func @apply_indexing_sequence_same_block(%arg0: tensor<128xf32>) -> (f32) { %c0 = arith.constant 0 : index diff --git a/xla/service/gpu/fusions/triton/triton_fusion_emitter_device_test.cc b/xla/service/gpu/fusions/triton/triton_fusion_emitter_device_test.cc index d79a54f4615681..1cb92de9436e49 100644 --- a/xla/service/gpu/fusions/triton/triton_fusion_emitter_device_test.cc +++ b/xla/service/gpu/fusions/triton/triton_fusion_emitter_device_test.cc @@ -216,7 +216,7 @@ ENTRY main { "num_warps":"1"}}}})"; TF_EXPECT_OK(CreateTritonIrAndFileCheck(this, kHloText, "triton_softmax_computation", R"( -CHECK: #indexing_map = #xla_gpu.indexing_map<"(d0) -> (d0 * 127), domain: d0 in [0, 124], is_simplified: true"> +CHECK: #indexing_map = #xla_gpu.indexing_map<"(d0) -> (d0 * 127), domain: d0 in [0, 124]"> CHECK: tt.func @triton_fn(%[[P0:[^:]*]]: !tt.ptr {tt.divisibility = 16 : i32}, %[[P1:[^:]*]]: !tt.ptr {tt.divisibility = 16 : i32}) { CHECK-DAG: %[[ZERO:.*]] = arith.constant 0 : i32 CHECK-DAG: %[[C125:.*]] = arith.constant 125 : i64 @@ -281,7 +281,7 @@ ENTRY main { "num_warps":"1"}}}})"; TF_EXPECT_OK(CreateTritonIrAndFileCheck(this, kHloText, "triton_softmax_computation", R"( -CHECK: #indexing_map = #xla_gpu.indexing_map<"(d0) -> (d0 * 127), domain: d0 in [0, 124], is_simplified: true"> +CHECK: #indexing_map = #xla_gpu.indexing_map<"(d0) -> (d0 * 127), domain: d0 in [0, 124]"> CHECK: tt.func @triton_fn( CHECK-SAME: %[[P0:[A-Za-z0-9_]*]]: !tt.ptr CHECK-SAME: %[[P1:[A-Za-z0-9_]*]]: !tt.ptr @@ -352,9 +352,9 @@ ENTRY main { TF_EXPECT_OK(CreateTritonIrAndFileCheck(this, kHloText, "triton_softmax_computation", R"( -CHECK: #[[MAP:.*]] = #xla_gpu.indexing_map<"(d0) -> (d0 floordiv 125), domain: d0 in [0, 1249], is_simplified: true"> -CHECK: #[[MAP1:.*]] = #xla_gpu.indexing_map<"(d0) -> (d0 mod 125), domain: d0 in [0, 1249], is_simplified: true"> -CHECK: #[[MAP2:.*]] = #xla_gpu.indexing_map<"(d0) -> (d0 * 127), domain: d0 in [0, 1249], is_simplified: true"> +CHECK: #[[MAP:.*]] = #xla_gpu.indexing_map<"(d0) -> (d0 floordiv 125), domain: d0 in [0, 1249]"> +CHECK: #[[MAP1:.*]] = #xla_gpu.indexing_map<"(d0) -> (d0 mod 125), domain: d0 in [0, 1249]"> +CHECK: #[[MAP2:.*]] = #xla_gpu.indexing_map<"(d0) -> (d0 * 127), domain: d0 in [0, 1249]"> CHECK: tt.func @triton_fn(%[[P0:[^:]*]]: !tt.ptr {tt.divisibility = 16 : i32}, %[[P1:[^:]*]]: !tt.ptr {tt.divisibility = 16 : i32}, %[[P2:[^:]*]]: !tt.ptr {tt.divisibility = 16 : i32}, %[[P3:[^:]*]]: !tt.ptr {tt.divisibility = 16 : i32}) { CHECK-DAG: %[[ZERO:.*]] = arith.constant 0 : i32 CHECK-DAG: %[[ZERO_64:.*]] = arith.constant 0 : i64 @@ -545,8 +545,8 @@ ENTRY main { TF_ASSERT_OK(CreateTritonIrAndFileCheck(this, kHloText, "triton_softmax_computation", R"( -// CHECK: #xla_gpu.indexing_map<"(d0) -> (d0 floordiv 32), domain: d0 in [0, 2047], is_simplified: true"> -// CHECK: #xla_gpu.indexing_map<"(d0) -> (d0 mod 32), domain: d0 in [0, 2047], is_simplified: true"> +// CHECK: #xla_gpu.indexing_map<"(d0) -> (d0 floordiv 32), domain: d0 in [0, 2047]"> +// CHECK: #xla_gpu.indexing_map<"(d0) -> (d0 mod 32), domain: d0 in [0, 
2047]"> // CHECK-LABEL: tt.func @triton_fn( // CHECK-SAME: %[[P0:[A-Za-z0-9_]*]]: !tt.ptr // CHECK-SAME: %[[P1:[A-Za-z0-9_]*]]: !tt.ptr diff --git a/xla/service/gpu/model/indexing_analysis.cc b/xla/service/gpu/model/indexing_analysis.cc index 9477826b0f801f..a8f70449e10f9b 100644 --- a/xla/service/gpu/model/indexing_analysis.cc +++ b/xla/service/gpu/model/indexing_analysis.cc @@ -528,7 +528,7 @@ HloInstructionIndexing ComputeOutputToInputReduceOpIndexing( output_shape.dimensions(), parallel_dims_sizes); IndexingMap inits_indexing_map = IndexingMap::FromTensorSizes( AffineMap::get(output_shape.rank(), /*symbolCount=*/0, {}, mlir_context), - output_shape.dimensions(), {}, /*is_simplified=*/true); + output_shape.dimensions(), {}); HloInstructionIndexing instr_indexing; instr_indexing.indexing_maps.resize(reduce->operand_count()); @@ -661,8 +661,7 @@ HloInstructionIndexing ComputeOutputToInputReduceWindowOpIndexing( // Indexing map for the init value. IndexingMap inits_indexing_map = IndexingMap::FromTensorSizes( AffineMap::get(output_shape.rank(), /*symbolCount=*/0, {}, mlir_context), - output_shape.dimensions(), /*symbol_upper_bounds=*/{}, - /*is_simplified=*/true); + output_shape.dimensions(), /*symbol_upper_bounds=*/{}); HloInstructionIndexing instr_indexing; instr_indexing.indexing_maps.resize(reduce_window->operand_count()); @@ -1154,8 +1153,7 @@ IndexingMap CreateIdentityMap(absl::Span dimensions, mlir::MLIRContext* mlir_context) { return IndexingMap::FromTensorSizes( AffineMap::getMultiDimIdentityMap(dimensions.size(), mlir_context), - /*dim_upper_bounds=*/dimensions, /*symbol_upper_bounds=*/{}, - /*is_simplified=*/dimensions.empty()); + /*dim_upper_bounds=*/dimensions, /*symbol_upper_bounds=*/{}); } IndexingMap CreateIdentityMap(const Shape& shape, MLIRContext* mlir_context) { diff --git a/xla/service/gpu/model/indexing_analysis_test.cc b/xla/service/gpu/model/indexing_analysis_test.cc index b3f4043d73825f..be577ffa434b02 100644 --- a/xla/service/gpu/model/indexing_analysis_test.cc +++ b/xla/service/gpu/model/indexing_analysis_test.cc @@ -64,15 +64,13 @@ TEST_F(IndexingAnalysisTest, FuseProducerConsumerOutputToInputIndexing) { (d0, d1) -> (d0, d1), domain: d0 in [0, 999], - d1 in [0, 999], - is_simplified: false + d1 in [0, 999] )"))), Pair(transpose, ElementsAre(MatchIndexingMap(R"( (d0, d1) -> (d0, d1), domain: d0 in [0, 999], - d1 in [0, 999], - is_simplified: false + d1 in [0, 999] )"))))); } @@ -98,29 +96,25 @@ TEST_F(IndexingAnalysisTest, ComputeGroupedOutputToInputIndexing) { (d0, d1) -> (d0, d1), domain: d0 in [0, 999], - d1 in [0, 999], - is_simplified: false + d1 in [0, 999] )"))), Pair(transpose, ElementsAre(MatchIndexingMap(R"( (d0, d1) -> (d0, d1), domain: d0 in [0, 999], - d1 in [0, 999], - is_simplified: true + d1 in [0, 999] )"))), Pair(parameter, UnorderedElementsAre(MatchIndexingMap(R"( (d0, d1) -> (d0, d1), domain: d0 in [0, 999], - d1 in [0, 999], - is_simplified: true + d1 in [0, 999] )"), MatchIndexingMap(R"( (d0, d1) -> (d1, d0), domain: d0 in [0, 999], - d1 in [0, 999], - is_simplified: true + d1 in [0, 999] )"))))); } @@ -159,34 +153,29 @@ TEST_F(IndexingAnalysisTest, Pair(root, ElementsAre(MatchIndexingMap(R"( (d0) -> (d0), domain: - d0 in [0, 31], - is_simplified: false + d0 in [0, 31] )"))), Pair(root->operand(0), ElementsAre(MatchIndexingMap(R"( (d0)[s0] -> (d0, s0), domain: d0 in [0, 31], - s0 in [0, 39], - is_simplified: true + s0 in [0, 39] )"))), Pair(root->operand(1), ElementsAre(MatchIndexingMap(R"( (d0)[s0] -> (d0, s0), domain: d0 in [0, 31], - s0 
in [0, 39], - is_simplified: true + s0 in [0, 39] )"))), Pair(root->operand(2), ElementsAre(MatchIndexingMap(R"( (d0) -> (), domain: - d0 in [0, 31], - is_simplified: true + d0 in [0, 31] )"))), Pair(root->operand(3), ElementsAre(MatchIndexingMap(R"( (d0) -> (), domain: - d0 in [0, 31], - is_simplified: true + d0 in [0, 31] )"))))); } @@ -216,8 +205,7 @@ TEST_F(IndexingAnalysisTest, ComputeGroupedOutputToInputIndexing_SingleOp) { (d0, d1) -> (d0, d1), domain: d0 in [0, 999], - d1 in [0, 999], - is_simplified: false + d1 in [0, 999] )"))))); } @@ -261,8 +249,7 @@ TEST_F(IndexingAnalysisTest, d0 in [0, 14], d1 in [0, 31], d2 in [0, 19], - d3 in [0, 63], - is_simplified: false + d3 in [0, 63] )"))), Pair(¶meter_0.instruction(), ElementsAre(MatchIndexingMap(R"( (d0, d1, d2, d3) -> (d0, d2), @@ -270,8 +257,7 @@ TEST_F(IndexingAnalysisTest, d0 in [0, 14], d1 in [0, 31], d2 in [0, 19], - d3 in [0, 63], - is_simplified: true + d3 in [0, 63] )"))))); } @@ -291,8 +277,7 @@ TEST_F(IndexingAnalysisTest, PhysicalLayoutTestOutputPermutation) { domain: d0 in [0, 29], d1 in [0, 9], - d2 in [0, 19], - is_simplified: false + d2 in [0, 19] )")); auto output_indexing = GetInputToOutputIndexing(root, /*input_id=*/0, @@ -303,8 +288,7 @@ TEST_F(IndexingAnalysisTest, PhysicalLayoutTestOutputPermutation) { domain: d0 in [0, 9], d1 in [0, 19], - d2 in [0, 29], - is_simplified: false + d2 in [0, 29] )")); } @@ -367,8 +351,7 @@ TEST_F(IndexingAnalysisTest, PhysicalLayoutTestInputPermutation) { domain: d0 in [0, 9], d1 in [0, 19], - d2 in [0, 29], - is_simplified: false + d2 in [0, 29] )")); auto output_indexing = GetInputToOutputIndexing(root, /*input_id=*/0, @@ -379,8 +362,7 @@ TEST_F(IndexingAnalysisTest, PhysicalLayoutTestInputPermutation) { domain: d0 in [0, 29], d1 in [0, 9], - d2 in [0, 19], - is_simplified: false + d2 in [0, 19] )")); } @@ -400,8 +382,7 @@ TEST_F(IndexingAnalysisTest, PhysicalLayoutTestInputAndOutputPermutation) { domain: d0 in [0, 29], d1 in [0, 9], - d2 in [0, 19], - is_simplified: false + d2 in [0, 19] )")); auto output_indexing = GetInputToOutputIndexing(root, /*input_id=*/0, @@ -412,8 +393,7 @@ TEST_F(IndexingAnalysisTest, PhysicalLayoutTestInputAndOutputPermutation) { domain: d0 in [0, 29], d1 in [0, 9], - d2 in [0, 19], - is_simplified: false + d2 in [0, 19] )")); } @@ -432,14 +412,12 @@ TEST_F(IndexingAnalysisTest, ElementwiseOp) { (d0, d1) -> (d0, d1), domain: d0 in [0, 9], - d1 in [0, 19], - is_simplified: false + d1 in [0, 19] operand id = 1 (d0, d1) -> (d0, d1), domain: d0 in [0, 9], - d1 in [0, 19], - is_simplified: false + d1 in [0, 19] )")); auto output_indexing_0 = GetInputToOutputIndexing(root, /*input_id=*/0); @@ -448,8 +426,7 @@ TEST_F(IndexingAnalysisTest, ElementwiseOp) { (d0, d1) -> (d0, d1), domain: d0 in [0, 9], - d1 in [0, 19], - is_simplified: false + d1 in [0, 19] )")); auto output_indexing_1 = GetInputToOutputIndexing(root, /*input_id=*/1); @@ -458,8 +435,7 @@ TEST_F(IndexingAnalysisTest, ElementwiseOp) { (d0, d1) -> (d0, d1), domain: d0 in [0, 9], - d1 in [0, 19], - is_simplified: false + d1 in [0, 19] )")); } @@ -483,14 +459,12 @@ TEST_F(IndexingAnalysisTest, Map) { (d0, d1) -> (d0, d1), domain: d0 in [0, 9], - d1 in [0, 19], - is_simplified: false + d1 in [0, 19] operand id = 1 (d0, d1) -> (d0, d1), domain: d0 in [0, 9], - d1 in [0, 19], - is_simplified: false + d1 in [0, 19] )")); auto output_indexing_0 = GetInputToOutputIndexing(root, /*input_id=*/0); @@ -499,8 +473,7 @@ TEST_F(IndexingAnalysisTest, Map) { (d0, d1) -> (d0, d1), domain: d0 in [0, 9], - d1 in [0, 
19], - is_simplified: false + d1 in [0, 19] )")); auto output_indexing_1 = GetInputToOutputIndexing(root, /*input_id=*/1); @@ -509,8 +482,7 @@ TEST_F(IndexingAnalysisTest, Map) { (d0, d1) -> (d0, d1), domain: d0 in [0, 9], - d1 in [0, 19], - is_simplified: false + d1 in [0, 19] )")); } @@ -528,8 +500,7 @@ TEST_F(IndexingAnalysisTest, BitcastIsReshape) { domain: d0 in [0, 3], d1 in [0, 7], - d2 in [0, 3], - is_simplified: true + d2 in [0, 3] )")); } @@ -548,8 +519,7 @@ TEST_F(IndexingAnalysisTest, BitcastIsTranspose) { d0 in [0, 2], d1 in [0, 5], d2 in [0, 127], - d3 in [0, 12287], - is_simplified: true + d3 in [0, 12287] )")); } @@ -567,8 +537,7 @@ TEST_F(IndexingAnalysisTest, BitcastIsTransposeReshapeTranspose) { (d0, d1) -> (d1, d0 floordiv 3, d0 mod 3), domain: d0 in [0, 50], - d1 in [0, 15], - is_simplified: true + d1 in [0, 15] )")); auto output_indexing = GetInputToOutputIndexing(root); EXPECT_THAT(output_indexing.ToString(), MatchIndexingString(R"( @@ -577,8 +546,7 @@ TEST_F(IndexingAnalysisTest, BitcastIsTransposeReshapeTranspose) { domain: d0 in [0, 15], d1 in [0, 16], - d2 in [0, 2], - is_simplified: true + d2 in [0, 2] )")); } @@ -597,8 +565,7 @@ TEST_F(IndexingAnalysisTest, BroadcastOp) { domain: d0 in [0, 9], d1 in [0, 19], - d2 in [0, 29], - is_simplified: false + d2 in [0, 29] )")); auto output_indexing = GetInputToOutputIndexing(root); EXPECT_THAT(output_indexing.ToString(), MatchIndexingString(R"( @@ -607,8 +574,7 @@ TEST_F(IndexingAnalysisTest, BroadcastOp) { domain: d0 in [0, 19], s0 in [0, 9], - s1 in [0, 29], - is_simplified: false + s1 in [0, 29] )")); } @@ -641,22 +607,19 @@ TEST_F(IndexingAnalysisTest, ConcatenateOp) { domain: d0 in [0, 1], d1 in [0, 4], - d2 in [0, 6], - is_simplified: false + d2 in [0, 6] operand id = 1 (d0, d1, d2) -> (d0, d1 - 5, d2), domain: d0 in [0, 1], d1 in [5, 15], - d2 in [0, 6], - is_simplified: false + d2 in [0, 6] operand id = 2 (d0, d1, d2) -> (d0, d1 - 16, d2), domain: d0 in [0, 1], d1 in [16, 32], - d2 in [0, 6], - is_simplified: false + d2 in [0, 6] )")); auto output_indexing_0 = GetInputToOutputIndexing(root, /*input_id=*/0); @@ -666,8 +629,7 @@ TEST_F(IndexingAnalysisTest, ConcatenateOp) { domain: d0 in [0, 1], d1 in [0, 4], - d2 in [0, 6], - is_simplified: false + d2 in [0, 6] )")); auto output_indexing_1 = GetInputToOutputIndexing(root, /*input_id=*/1); @@ -677,8 +639,7 @@ TEST_F(IndexingAnalysisTest, ConcatenateOp) { domain: d0 in [0, 1], d1 in [0, 10], - d2 in [0, 6], - is_simplified: false + d2 in [0, 6] )")); auto output_indexing_2 = GetInputToOutputIndexing(root, /*input_id=*/2); @@ -688,8 +649,7 @@ TEST_F(IndexingAnalysisTest, ConcatenateOp) { domain: d0 in [0, 1], d1 in [0, 16], - d2 in [0, 6], - is_simplified: false + d2 in [0, 6] )")); } @@ -721,29 +681,25 @@ TEST_F(IndexingAnalysisTest, DynamicSliceOp) { (d0, d1, d2) -> (), s2 in [0, 226], hlo: %of3 = s32[] parameter(3), - (d0, d1, d2) -> (), - is_simplified: false + (d0, d1, d2) -> () operand id = 1 (d0, d1, d2) -> (), domain: d0 in [0, 0], d1 in [0, 1], - d2 in [0, 31], - is_simplified: false + d2 in [0, 31] operand id = 2 (d0, d1, d2) -> (), domain: d0 in [0, 0], d1 in [0, 1], - d2 in [0, 31], - is_simplified: false + d2 in [0, 31] operand id = 3 (d0, d1, d2) -> (), domain: d0 in [0, 0], d1 in [0, 1], - d2 in [0, 31], - is_simplified: false + d2 in [0, 31] )")); } @@ -764,8 +720,7 @@ TEST_F(IndexingAnalysisTest, DynamicUpdateSliceOp) { (d0, d1) -> (d0, d1), domain: d0 in [0, 19], - d1 in [0, 29], - is_simplified: false + d1 in [0, 29] operand id = 1 (d0, d1)[s0, 
s1] -> (d0 - s0, d1 - s1), domain: @@ -776,20 +731,17 @@ TEST_F(IndexingAnalysisTest, DynamicUpdateSliceOp) { (d0, d1) -> (), s1 in [0, 20], hlo: %of2 = s32[] parameter(3), - (d0, d1) -> (), - is_simplified: false + (d0, d1) -> () operand id = 2 (d0, d1) -> (), domain: d0 in [0, 19], - d1 in [0, 29], - is_simplified: false + d1 in [0, 29] operand id = 3 (d0, d1) -> (), domain: d0 in [0, 19], - d1 in [0, 29], - is_simplified: false + d1 in [0, 29] )")); } @@ -811,13 +763,11 @@ TEST_F(IndexingAnalysisTest, FusionOpWithSingleBinaryOp) { operand id = 0 (d0) -> (d0), domain: - d0 in [0, 99], - is_simplified: true + d0 in [0, 99] operand id = 1 (d0) -> (d0), domain: - d0 in [0, 99], - is_simplified: true + d0 in [0, 99] )")); } @@ -891,8 +841,7 @@ TEST_F(IndexingAnalysisTest, FusionOpWithDot) { d3 in [0, 0], d4 in [0, 5], d5 in [0, 127], - s0 in [0, 767], - is_simplified: true + s0 in [0, 767] operand id = 1 (d0, d1, d2, d3, d4, d5)[s0] -> (d0 * 768 + s0), domain: @@ -902,8 +851,7 @@ TEST_F(IndexingAnalysisTest, FusionOpWithDot) { d3 in [0, 0], d4 in [0, 5], d5 in [0, 127], - s0 in [0, 767], - is_simplified: true + s0 in [0, 767] operand id = 2 (d0, d1, d2, d3, d4, d5) -> (d1), domain: @@ -912,8 +860,7 @@ TEST_F(IndexingAnalysisTest, FusionOpWithDot) { d2 in [0, 2], d3 in [0, 0], d4 in [0, 5], - d5 in [0, 127], - is_simplified: true + d5 in [0, 127] operand id = 3 (d0, d1, d2, d3, d4, d5)[s0] -> (d1, d0 * 768 + s0), domain: @@ -923,8 +870,7 @@ TEST_F(IndexingAnalysisTest, FusionOpWithDot) { d3 in [0, 0], d4 in [0, 5], d5 in [0, 127], - s0 in [0, 767], - is_simplified: true + s0 in [0, 767] operand id = 4 (d0, d1, d2, d3, d4, d5)[s0] -> (d1, d0 * 768 + s0), domain: @@ -934,8 +880,7 @@ TEST_F(IndexingAnalysisTest, FusionOpWithDot) { d3 in [0, 0], d4 in [0, 5], d5 in [0, 127], - s0 in [0, 767], - is_simplified: true + s0 in [0, 767] operand id = 5 (d0, d1, d2, d3, d4, d5) -> (d2, d4, d5), domain: @@ -944,8 +889,7 @@ TEST_F(IndexingAnalysisTest, FusionOpWithDot) { d2 in [0, 2], d3 in [0, 0], d4 in [0, 5], - d5 in [0, 127], - is_simplified: true + d5 in [0, 127] )")); } @@ -1002,16 +946,14 @@ TEST_F(IndexingAnalysisTest, FusionOpWithSoftmax) { d0 in [0, 1], d1 in [0, 64], d2 in [0, 124], - s0 in [0, 124], - is_simplified: true + s0 in [0, 124] )"), MatchIndexingMap(R"( (d0, d1, d2) -> (d0, d1, d2), domain: d0 in [0, 1], d1 in [0, 64], - d2 in [0, 124], - is_simplified: true + d2 in [0, 124] )")))); } @@ -1033,15 +975,13 @@ TEST_F(IndexingAnalysisTest, FusionOpTensorPlusTransposedTensor) { (d0, d1) -> (d0, d1), domain: d0 in [0, 999], - d1 in [0, 999], - is_simplified: true + d1 in [0, 999] )"), MatchIndexingMap(R"( (d0, d1) -> (d1, d0), domain: d0 in [0, 999], - d1 in [0, 999], - is_simplified: true + d1 in [0, 999] )")))); } @@ -1071,38 +1011,32 @@ TEST_F(IndexingAnalysisTest, FusionExponentialDuplication) { ElementsAre(UnorderedElementsAre(MatchIndexingMap(R"( (d0) -> (d0 + 1), domain: - d0 in [0, 1], - is_simplified: true + d0 in [0, 1] )"), MatchIndexingMap(R"( (d0) -> (d0), domain: - d0 in [0, 1], - is_simplified: true + d0 in [0, 1] )"), MatchIndexingMap(R"( (d0) -> (d0 + 2), domain: - d0 in [0, 1], - is_simplified: true + d0 in [0, 1] )")), UnorderedElementsAre(MatchIndexingMap(R"( (d0) -> (d0 + 2), domain: - d0 in [0, 1], - is_simplified: true + d0 in [0, 1] )"), MatchIndexingMap(R"( (d0) -> (d0 + 1), domain: - d0 in [0, 1], - is_simplified: true + d0 in [0, 1] )"), MatchIndexingMap(R"( (d0) -> (d0), domain: - d0 in [0, 1], - is_simplified: true + d0 in [0, 1] )")))); } @@ -1130,8 +1064,7 @@ 
TEST_F(IndexingAnalysisTest, GatherOp) { (d0, d1, d2, d3) -> (d0, 0), s1 in [0, 68], hlo: %indices = s32[1806,2]{1,0} parameter(1), - (d0, d1, d2, d3) -> (d0, 1), - is_simplified: false + (d0, d1, d2, d3) -> (d0, 1) operand id = 1 (d0, d1, d2, d3)[s0] -> (d0, s0), domain: @@ -1139,8 +1072,7 @@ TEST_F(IndexingAnalysisTest, GatherOp) { d1 in [0, 6], d2 in [0, 7], d3 in [0, 3], - s0 in [0, 1], - is_simplified: false + s0 in [0, 1] )")); } @@ -1173,13 +1105,11 @@ TEST_F(IndexingAnalysisTest, FusionOpWithReduceOfReduce) { d0 in [0, 9], s0 in [0, 149], s1 in [0, 49], - s2 in [0, 19], - is_simplified: true + s2 in [0, 19] operand id = 1 (d0) -> (), domain: - d0 in [0, 9], - is_simplified: true + d0 in [0, 9] )")); } @@ -1211,14 +1141,12 @@ TEST_F(IndexingAnalysisTest, FusionOpWithReduceOfBroadcast) { domain: d0 in [0, 14], d1 in [0, 63], - s0 in [0, 19], - is_simplified: true + s0 in [0, 19] operand id = 1 (d0, d1) -> (), domain: d0 in [0, 14], - d1 in [0, 63], - is_simplified: true + d1 in [0, 63] )")); } @@ -1253,8 +1181,7 @@ TEST_F(IndexingAnalysisTest, FusionOpWithTransposeOfTranspose) { domain: d0 in [0, 9], d1 in [0, 49], - d2 in [0, 19], - is_simplified: true + d2 in [0, 19] )")); } @@ -1286,13 +1213,11 @@ TEST_F(IndexingAnalysisTest, FusionOpWithReducedSlice) { domain: d0 in [0, 31], s0 in [0, 15], - s1 in [0, 127], - is_simplified: true + s1 in [0, 127] operand id = 1 (d0) -> (), domain: - d0 in [0, 31], - is_simplified: true + d0 in [0, 31] )")); } @@ -1313,8 +1238,7 @@ TEST_F(IndexingAnalysisTest, FusionOpWithReshape_CollapseOfExpand) { operand id = 0 (d0) -> (d0), domain: - d0 in [0, 127], - is_simplified: true + d0 in [0, 127] )")); } @@ -1336,8 +1260,7 @@ TEST_F(IndexingAnalysisTest, FusionOpWithReshape_ExpandOfCollapse) { (d0, d1) -> (d0, d1), domain: d0 in [0, 7], - d1 in [0, 15], - is_simplified: true + d1 in [0, 15] )")); } @@ -1360,8 +1283,7 @@ TEST_F(IndexingAnalysisTest, FusionOpWithReshape_ChainedGenericReshapes) { domain: d0 in [0, 9], d1 in [0, 9], - d2 in [0, 9], - is_simplified: true + d2 in [0, 9] )")); } @@ -1386,8 +1308,7 @@ TEST_F(IndexingAnalysisTest, FusionOpWithSliceOfSlice) { domain: d0 in [0, 6], d1 in [0, 8], - d2 in [0, 23], - is_simplified: true + d2 in [0, 23] )")); } @@ -1434,32 +1355,27 @@ TEST_F(IndexingAnalysisTest, FusionOpWithDynSliceOfDynSlice) { (d0, d1) -> (), s3 in [0, 16], hlo: %of22 = s32[] parameter(4), - (d0, d1) -> (), - is_simplified: true + (d0, d1) -> () operand id = 1 (d0, d1) -> (), domain: d0 in [0, 24], - d1 in [0, 15], - is_simplified: true + d1 in [0, 15] operand id = 2 (d0, d1) -> (), domain: d0 in [0, 24], - d1 in [0, 15], - is_simplified: true + d1 in [0, 15] operand id = 3 (d0, d1) -> (), domain: d0 in [0, 24], - d1 in [0, 15], - is_simplified: true + d1 in [0, 15] operand id = 4 (d0, d1) -> (), domain: d0 in [0, 24], - d1 in [0, 15], - is_simplified: true + d1 in [0, 15] )")); } @@ -1488,22 +1404,19 @@ TEST_F(IndexingAnalysisTest, FusionOpSliceOfAllConcatenateOpInputs) { domain: d0 in [0, 1], d1 in [0, 1], - d2 in [0, 6], - is_simplified: true + d2 in [0, 6] operand id = 1 (d0, d1, d2) -> (d0, d1 * 3 - 5, d2), domain: d0 in [0, 1], d1 in [2, 5], - d2 in [0, 6], - is_simplified: true + d2 in [0, 6] operand id = 2 (d0, d1, d2) -> (d0, d1 * 3 - 16, d2), domain: d0 in [0, 1], d1 in [6, 10], - d2 in [0, 6], - is_simplified: true + d2 in [0, 6] )")); } @@ -1532,8 +1445,7 @@ TEST_F(IndexingAnalysisTest, FusionOpSliceOfOneOfConcatenateOpInputs) { domain: d0 in [0, 1], d1 in [0, 2], - d2 in [0, 6], - is_simplified: true + d2 in [0, 6] 
operand id = 1 KNOWN EMPTY operand id = 2 @@ -1562,15 +1474,13 @@ TEST_F(IndexingAnalysisTest, FusionOpReshapeOfConcat) { domain: d0 in [0, 3], d1 in [0, 7], - d0 * 8 + d1 in [0, 1], - is_simplified: true + d0 * 8 + d1 in [0, 1] operand id = 1 (d0, d1) -> (d0 * 8 + d1 - 2), domain: d0 in [0, 3], d1 in [0, 7], - d0 * 8 + d1 in [2, 31], - is_simplified: true + d0 * 8 + d1 in [2, 31] )")); } @@ -1597,8 +1507,7 @@ TEST_F(IndexingAnalysisTest, ReshapeOpCollapseShape) { operand id = 0 (d0) -> (d0 floordiv 8, d0 mod 8), domain: - d0 in [0, 31], - is_simplified: true + d0 in [0, 31] )")); } @@ -1615,8 +1524,7 @@ TEST_F(IndexingAnalysisTest, ReshapeOpExpandShape) { (d0, d1) -> (d0 * 8 + d1), domain: d0 in [0, 3], - d1 in [0, 7], - is_simplified: true + d1 in [0, 7] )")); } @@ -1635,8 +1543,7 @@ TEST_F(IndexingAnalysisTest, ReshapeOpExpandAndCollapseShape) { domain: d0 in [0, 31], d1 in [0, 2], - d2 in [0, 3], - is_simplified: true + d2 in [0, 3] )")); auto output_indexing = GetInputToOutputIndexing(root); @@ -1646,8 +1553,7 @@ TEST_F(IndexingAnalysisTest, ReshapeOpExpandAndCollapseShape) { domain: d0 in [0, 3], d1 in [0, 7], - d2 in [0, 11], - is_simplified: true + d2 in [0, 11] )")); } @@ -1665,8 +1571,7 @@ TEST_F(IndexingAnalysisTest, ReshapeOpExpandSubshapeOnly) { domain: d0 in [0, 3], d1 in [0, 3], - d2 in [0, 7], - is_simplified: true + d2 in [0, 7] )")); } @@ -1684,8 +1589,7 @@ TEST_F(IndexingAnalysisTest, ReshapeOpGenericReshape2DTo3D) { domain: d0 in [0, 1], d1 in [0, 3], - d2 in [0, 3], - is_simplified: true + d2 in [0, 3] )")); } @@ -1704,8 +1608,7 @@ TEST_F(IndexingAnalysisTest, ReshapeOpGenericReshape3DTo2D) { d1 mod 4), domain: d0 in [0, 3], - d1 in [0, 7], - is_simplified: true + d1 in [0, 7] )")); } @@ -1724,14 +1627,12 @@ TEST_F(IndexingAnalysisTest, PadOp) { domain: d0 in [1, 7], d1 in [4, 7], - (d0 - 1) mod 2 in [0, 0], - is_simplified: false + (d0 - 1) mod 2 in [0, 0] operand id = 1 (d0, d1) -> (), domain: d0 in [0, 11], - d1 in [0, 15], - is_simplified: false + d1 in [0, 15] )")); } @@ -1749,14 +1650,12 @@ TEST_F(IndexingAnalysisTest, PadOpNoInterior) { (d0, d1) -> (d0 - 1, d1), domain: d0 in [1, 2], - d1 in [0, 7], - is_simplified: false + d1 in [0, 7] operand id = 1 (d0, d1) -> (), domain: d0 in [0, 9], - d1 in [0, 7], - is_simplified: false + d1 in [0, 7] )")); } @@ -1779,13 +1678,11 @@ TEST_F(IndexingAnalysisTest, PadOpNegativePadding) { (d0) -> ((d0 + 3) floordiv 2), domain: d0 in [0, 4], - (d0 + 3) mod 2 in [0, 0], - is_simplified: false + (d0 + 3) mod 2 in [0, 0] operand id = 1 (d0) -> (), domain: - d0 in [0, 4], - is_simplified: false + d0 in [0, 4] )")); } @@ -1812,14 +1709,12 @@ TEST_F(IndexingAnalysisTest, ReduceOp) { d0 in [0, 149], d1 in [0, 9], s0 in [0, 19], - s1 in [0, 49], - is_simplified: false + s1 in [0, 49] operand id = 1 (d0, d1) -> (), domain: d0 in [0, 149], - d1 in [0, 9], - is_simplified: true + d1 in [0, 9] )")); auto output_indexing_0 = GetInputToOutputIndexing(root, 0); @@ -1830,8 +1725,7 @@ TEST_F(IndexingAnalysisTest, ReduceOp) { d0 in [0, 149], d1 in [0, 19], d2 in [0, 9], - d3 in [0, 49], - is_simplified: false + d3 in [0, 49] )")); auto output_indexing_1 = GetInputToOutputIndexing(root, 1); EXPECT_THAT(output_indexing_1.ToString(), MatchIndexingString(R"( @@ -1839,8 +1733,7 @@ TEST_F(IndexingAnalysisTest, ReduceOp) { ()[s0, s1] -> (s0, s1), domain: s0 in [0, 149], - s1 in [0, 9], - is_simplified: false + s1 in [0, 9] )")); } @@ -1873,24 +1766,20 @@ TEST_F(IndexingAnalysisTest, VariadicReduceOp) { (d0)[s0] -> (s0, d0), domain: d0 in [0, 9], - s0 
in [0, 255], - is_simplified: false + s0 in [0, 255] operand id = 1 (d0)[s0] -> (s0, d0), domain: d0 in [0, 9], - s0 in [0, 255], - is_simplified: false + s0 in [0, 255] operand id = 2 (d0) -> (), domain: - d0 in [0, 9], - is_simplified: true + d0 in [0, 9] operand id = 3 (d0) -> (), domain: - d0 in [0, 9], - is_simplified: true + d0 in [0, 9] )")); auto output_indexing_1 = GetOutputToInputIndexing(root, /*output_id=*/1); @@ -1899,32 +1788,27 @@ TEST_F(IndexingAnalysisTest, VariadicReduceOp) { (d0)[s0] -> (s0, d0), domain: d0 in [0, 9], - s0 in [0, 255], - is_simplified: false + s0 in [0, 255] operand id = 1 (d0)[s0] -> (s0, d0), domain: d0 in [0, 9], - s0 in [0, 255], - is_simplified: false + s0 in [0, 255] operand id = 2 (d0) -> (), domain: - d0 in [0, 9], - is_simplified: true + d0 in [0, 9] operand id = 3 (d0) -> (), domain: - d0 in [0, 9], - is_simplified: true + d0 in [0, 9] )")); constexpr std::string_view kInputToOutputIndexing = R"( (d0, d1) -> (d1), domain: d0 in [0, 255], - d1 in [0, 9], - is_simplified: false + d1 in [0, 9] )"; auto input_indexing_0 = GetInputToOutputIndexing(root, /*input_id=*/0); EXPECT_THAT( @@ -1941,8 +1825,7 @@ TEST_F(IndexingAnalysisTest, VariadicReduceOp) { constexpr std::string_view kInitToOutputIndexing = R"( ()[s0] -> (s0), domain: - s0 in [0, 9], - is_simplified: false + s0 in [0, 9] )"; auto input_indexing_2 = GetInputToOutputIndexing(root, /*input_id=*/2); EXPECT_THAT( @@ -1978,14 +1861,12 @@ TEST_F(IndexingAnalysisTest, ReduceWindowOp_NoPadding) { domain: d0 in [0, 1023], d1 in [0, 2], - s0 in [0, 511], - is_simplified: true + s0 in [0, 511] operand id = 1 (d0, d1) -> (), domain: d0 in [0, 1023], - d1 in [0, 2], - is_simplified: true + d1 in [0, 2] )")); } @@ -2014,14 +1895,12 @@ TEST_F(IndexingAnalysisTest, ReduceWindowOp_PaddingAndWindowStride) { s0 in [0, 2], s1 in [0, 1], d0 * 2 + s0 in [1, 13], - d1 + s1 in [0, 16], - is_simplified: true + d1 + s1 in [0, 16] operand id = 1 (d0, d1) -> (), domain: d0 in [0, 6], - d1 in [0, 16], - is_simplified: true + d1 in [0, 16] )")); } @@ -2048,14 +1927,12 @@ TEST_F(IndexingAnalysisTest, ReduceWindowOp_BaseDilation) { d0 in [0, 2], d1 in [0, 4], d0 mod 2 in [0, 0], - d1 mod 2 in [0, 0], - is_simplified: true + d1 mod 2 in [0, 0] operand id = 1 (d0, d1) -> (), domain: d0 in [0, 2], - d1 in [0, 4], - is_simplified: true + d1 in [0, 4] )")); } @@ -2081,14 +1958,12 @@ TEST_F(IndexingAnalysisTest, ReduceWindowOp_WindowDilation) { domain: d0 in [0, 3], d1 in [0, 2], - s0 in [0, 1], - is_simplified: true + s0 in [0, 1] operand id = 1 (d0, d1) -> (), domain: d0 in [0, 3], - d1 in [0, 2], - is_simplified: true + d1 in [0, 2] )")); } @@ -2122,28 +1997,24 @@ TEST_F(IndexingAnalysisTest, ReduceWindowOp_Variadic) { d0 in [0, 0], d1 in [0, 1], s0 in [0, 1], - s1 in [0, 1], - is_simplified: true + s1 in [0, 1] operand id = 1 (d0, d1)[s0, s1] -> (s0, d1 + s1), domain: d0 in [0, 0], d1 in [0, 1], s0 in [0, 1], - s1 in [0, 1], - is_simplified: true + s1 in [0, 1] operand id = 2 (d0, d1) -> (), domain: d0 in [0, 0], - d1 in [0, 1], - is_simplified: true + d1 in [0, 1] operand id = 3 (d0, d1) -> (), domain: d0 in [0, 0], - d1 in [0, 1], - is_simplified: true + d1 in [0, 1] )")); auto input_indexing_1 = GetOutputToInputIndexing(root, /*output_id=*/1); EXPECT_THAT(input_indexing_1.ToString(), MatchIndexingString(R"( @@ -2153,28 +2024,24 @@ TEST_F(IndexingAnalysisTest, ReduceWindowOp_Variadic) { d0 in [0, 0], d1 in [0, 1], s0 in [0, 1], - s1 in [0, 1], - is_simplified: true + s1 in [0, 1] operand id = 1 (d0, d1)[s0, s1] -> (s0, 
d1 + s1), domain: d0 in [0, 0], d1 in [0, 1], s0 in [0, 1], - s1 in [0, 1], - is_simplified: true + s1 in [0, 1] operand id = 2 (d0, d1) -> (), domain: d0 in [0, 0], - d1 in [0, 1], - is_simplified: true + d1 in [0, 1] operand id = 3 (d0, d1) -> (), domain: d0 in [0, 0], - d1 in [0, 1], - is_simplified: true + d1 in [0, 1] )")); } @@ -2199,8 +2066,7 @@ TEST_F(IndexingAnalysisTest, ConvolutionOp_NoPadding) { d3 in [0, 7], s0 in [0, 2], s1 in [0, 4], - s2 in [0, 3], - is_simplified: false + s2 in [0, 3] operand id = 1 (d0, d1, d2, d3)[s0, s1, s2] -> (s2, s0, s1, d3), domain: @@ -2210,8 +2076,7 @@ TEST_F(IndexingAnalysisTest, ConvolutionOp_NoPadding) { d3 in [0, 7], s0 in [0, 2], s1 in [0, 4], - s2 in [0, 3], - is_simplified: false + s2 in [0, 3] )")); } @@ -2238,8 +2103,7 @@ TEST_F(IndexingAnalysisTest, ConvolutionOp_PaddingAndWindowStride) { s1 in [0, 4], s2 in [0, 3], d1 * 2 + s0 in [1, 12], - d2 * 2 + s1 in [2, 11], - is_simplified: false + d2 * 2 + s1 in [2, 11] operand id = 1 (d0, d1, d2, d3)[s0, s1, s2] -> (s2, s0, s1, d3), domain: @@ -2249,8 +2113,7 @@ TEST_F(IndexingAnalysisTest, ConvolutionOp_PaddingAndWindowStride) { d3 in [0, 7], s0 in [0, 2], s1 in [0, 4], - s2 in [0, 3], - is_simplified: false + s2 in [0, 3] )")); } @@ -2277,8 +2140,7 @@ TEST_F(IndexingAnalysisTest, ConvolutionOp_LhsDilation) { s1 in [0, 4], s2 in [0, 3], (d1 + s0) mod 2 in [0, 0], - (d2 + s1) mod 2 in [0, 0], - is_simplified: false + (d2 + s1) mod 2 in [0, 0] operand id = 1 (d0, d1, d2, d3)[s0, s1, s2] -> (s2, s0, s1, d3), domain: @@ -2288,8 +2150,7 @@ TEST_F(IndexingAnalysisTest, ConvolutionOp_LhsDilation) { d3 in [0, 7], s0 in [0, 2], s1 in [0, 4], - s2 in [0, 3], - is_simplified: false + s2 in [0, 3] )")); } @@ -2314,8 +2175,7 @@ TEST_F(IndexingAnalysisTest, ConvolutionOp_RhsDilation) { d3 in [0, 7], s0 in [0, 2], s1 in [0, 4], - s2 in [0, 3], - is_simplified: false + s2 in [0, 3] operand id = 1 (d0, d1, d2, d3)[s0, s1, s2] -> (s2, s0, s1, d3), domain: @@ -2325,8 +2185,7 @@ TEST_F(IndexingAnalysisTest, ConvolutionOp_RhsDilation) { d3 in [0, 7], s0 in [0, 2], s1 in [0, 4], - s2 in [0, 3], - is_simplified: false + s2 in [0, 3] )")); } @@ -2351,8 +2210,7 @@ TEST_F(IndexingAnalysisTest, ConvolutionOp_FeatureGroups) { d3 in [0, 47], s0 in [0, 2], s1 in [0, 4], - s2 in [0, 3], - is_simplified: false + s2 in [0, 3] operand id = 1 (d0, d1, d2, d3)[s0, s1, s2] -> (s2, s0, s1, d3), domain: @@ -2362,8 +2220,7 @@ TEST_F(IndexingAnalysisTest, ConvolutionOp_FeatureGroups) { d3 in [0, 47], s0 in [0, 2], s1 in [0, 4], - s2 in [0, 3], - is_simplified: false + s2 in [0, 3] )")); } @@ -2389,8 +2246,7 @@ TEST_F(IndexingAnalysisTest, ConvolutionOp_BatchGroups) { s0 in [0, 2], s1 in [0, 4], s2 in [0, 3], - s3 in [0, 6], - is_simplified: false + s3 in [0, 6] operand id = 1 (d0, d1, d2, d3)[s0, s1, s2] -> (s2, s0, s1, d3), domain: @@ -2400,8 +2256,7 @@ TEST_F(IndexingAnalysisTest, ConvolutionOp_BatchGroups) { d3 in [0, 20], s0 in [0, 2], s1 in [0, 4], - s2 in [0, 3], - is_simplified: false + s2 in [0, 3] )")); } @@ -2421,8 +2276,7 @@ TEST_F(IndexingAnalysisTest, ReverseOp) { d0 in [0, 0], d1 in [0, 16], d2 in [0, 8], - d3 in [0, 8], - is_simplified: false + d3 in [0, 8] )")); auto output_indexing = GetInputToOutputIndexing(root); @@ -2433,8 +2287,7 @@ TEST_F(IndexingAnalysisTest, ReverseOp) { d0 in [0, 0], d1 in [0, 16], d2 in [0, 8], - d3 in [0, 8], - is_simplified: false + d3 in [0, 8] )")); } @@ -2459,8 +2312,7 @@ TEST_F(IndexingAnalysisTest, ReverseReshape) { (d0, d1) -> (d0, d1), domain: d0 in [0, 9], - d1 in [0, 10], - 
is_simplified: true + d1 in [0, 10] )")); } @@ -2480,8 +2332,7 @@ TEST_F(IndexingAnalysisTest, SliceOp) { domain: d0 in [0, 4], d1 in [0, 2], - d2 in [0, 24], - is_simplified: false + d2 in [0, 24] )")); auto output_indexing = GetInputToOutputIndexing(root); EXPECT_THAT(output_indexing.ToString(), MatchIndexingString(R"( @@ -2496,8 +2347,7 @@ TEST_F(IndexingAnalysisTest, SliceOp) { d1 in [3, 17], d2 in [0, 48], (d1 - 3) mod 7 in [0, 0], - d2 mod 2 in [0, 0], - is_simplified: false + d2 mod 2 in [0, 0] )")); } @@ -2517,8 +2367,7 @@ TEST_F(IndexingAnalysisTest, TransposeOp) { d0 in [0, 2], d1 in [0, 5], d2 in [0, 127], - d3 in [0, 12287], - is_simplified: false + d3 in [0, 12287] )")); EXPECT_THAT(GetInputToOutputIndexing(root).ToString(), MatchIndexingString(R"( operand id = 0 @@ -2527,8 +2376,7 @@ TEST_F(IndexingAnalysisTest, TransposeOp) { d0 in [0, 2], d1 in [0, 12287], d2 in [0, 5], - d3 in [0, 127], - is_simplified: false + d3 in [0, 127] )")); } @@ -2547,8 +2395,7 @@ TEST_F(IndexingAnalysisTest, TransposeOp4D) { d0 in [0, 2], d1 in [0, 5], d2 in [0, 127], - d3 in [0, 12287], - is_simplified: true + d3 in [0, 12287] )")); } @@ -2574,8 +2421,7 @@ TEST_F(IndexingAnalysisTest, DotOp) { d4 in [0, 15], d5 in [0, 21], s0 in [0, 17], - s1 in [0, 16], - is_simplified: false + s1 in [0, 16] operand id = 1 (d0, d1, d2, d3, d4, d5)[s0, s1] -> (s1, d0, d4, s0, d5, d1), domain: @@ -2586,8 +2432,7 @@ TEST_F(IndexingAnalysisTest, DotOp) { d4 in [0, 15], d5 in [0, 21], s0 in [0, 17], - s1 in [0, 16], - is_simplified: false + s1 in [0, 16] )")); } @@ -2648,8 +2493,7 @@ TEST_F(IndexingAnalysisTest, FusionWithUnsupportedOp) { (d0, d1) -> (d0 * 6, d1 * 2), domain: d0 in [0, 3], - d1 in [0, 2], - is_simplified: true + d1 in [0, 2] operand id = 1 unknown indexing operand id = 2 @@ -2686,8 +2530,7 @@ TEST_F(IndexingAnalysisTest, EpilogueIndexing) { (d0, d1) -> (d1 * 1000 + d0), domain: d0 in [0, 999], - d1 in [0, 999], - is_simplified: true + d1 in [0, 999] )")); } @@ -2716,8 +2559,7 @@ TEST_F(IndexingAnalysisTest, EpilogueIndexing_NoEpilogue) { (d0, d1) -> (d0, d1), domain: d0 in [0, 999], - d1 in [0, 999], - is_simplified: false + d1 in [0, 999] )")); } @@ -2735,18 +2577,15 @@ TEST_F(IndexingAnalysisTest, BroadcastingElementwise) { operand id = 0 (d0, d1) -> (), domain: d0 in [0, 999], - d1 in [0, 999], - is_simplified: false + d1 in [0, 999] operand id = 1 (d0, d1) -> (d0, d1), domain: d0 in [0, 999], - d1 in [0, 999], - is_simplified: false + d1 in [0, 999] operand id = 2 (d0, d1) -> (d0, d1), domain: d0 in [0, 999], - d1 in [0, 999], - is_simplified: false + d1 in [0, 999] )")); } @@ -2778,14 +2617,12 @@ TEST_F(IndexingAnalysisTest, FusionOpWithDUS) { s0 in [0, 4096], hlo: %slice = s32[1]{0} parameter(1), (d0, d1) -> (0), - d1 + s0 in [4096, 8191], - is_simplified: true + d1 + s0 in [4096, 8191] operand id = 1 (d0, d1) -> (0), domain: d0 in [0, 0], - d1 in [0, 4095], - is_simplified: true + d1 in [0, 4095] )")); } diff --git a/xla/service/gpu/model/indexing_map.cc b/xla/service/gpu/model/indexing_map.cc index 8e431976467734..f7ec1f1f83dd76 100644 --- a/xla/service/gpu/model/indexing_map.cc +++ b/xla/service/gpu/model/indexing_map.cc @@ -1001,8 +1001,7 @@ std::vector RangeVarsFromTensorSizes( IndexingMap::IndexingMap( AffineMap affine_map, std::vector dimensions, std::vector range_vars, std::vector rt_vars, - absl::Span const> constraints, - bool is_simplified) + absl::Span const> constraints) : affine_map_(affine_map), dim_vars_(std::move(dimensions)), range_vars_(std::move(range_vars)), @@ -1014,7 
+1013,6 @@ IndexingMap::IndexingMap( for (const auto& [expr, range] : constraints) { AddConstraint(expr, range); } - is_simplified_ = is_simplified; } IndexingMap::IndexingMap( @@ -1034,13 +1032,10 @@ IndexingMap::IndexingMap( IndexingMap IndexingMap::FromTensorSizes( AffineMap affine_map, absl::Span dim_upper_bounds, - absl::Span symbol_upper_bounds, bool is_simplified) { - return IndexingMap{affine_map, - DimVarsFromTensorSizes(dim_upper_bounds), + absl::Span symbol_upper_bounds) { + return IndexingMap{affine_map, DimVarsFromTensorSizes(dim_upper_bounds), RangeVarsFromTensorSizes(symbol_upper_bounds), - /*rt_vars=*/{}, - /*constraints=*/{}, - is_simplified}; + /*rt_vars=*/{}}; } RangeEvaluator IndexingMap::GetRangeEvaluator() const { @@ -1052,7 +1047,6 @@ const Interval& IndexingMap::GetDimensionBound(int64_t dim_id) const { } Interval& IndexingMap::GetMutableDimensionBound(int64_t dim_id) { - is_simplified_ = false; return dim_vars_[dim_id].bounds; } @@ -1075,7 +1069,6 @@ const Interval& IndexingMap::GetSymbolBound(int64_t symbol_id) const { } Interval& IndexingMap::GetMutableSymbolBound(int64_t symbol_id) { - is_simplified_ = false; // Because affine map symbols are packed like [range_vars, rt_vars], // we have to pick the correct bounds. int64_t range_var_count = GetRangeVarsCount(); @@ -1131,7 +1124,6 @@ void IndexingMap::AddConstraint(mlir::AffineExpr expr, Interval range) { ResetToKnownEmpty(); } } - is_simplified_ = false; } void IndexingMap::EraseConstraint(mlir::AffineExpr expr) { @@ -1305,7 +1297,7 @@ bool IndexingMap::Verify(std::ostream& out) const { // simplification, because the ranges of constraints were already optimized once // when IndexingMap was constructed. bool IndexingMap::Simplify() { - if (IsSimplified() || IsUndefined() || IsKnownEmpty()) return false; + if (IsUndefined() || IsKnownEmpty()) return false; bool rtvars_were_eliminated = ReplaceConstantRTVars(); @@ -1336,7 +1328,6 @@ bool IndexingMap::Simplify() { if (affine_map_was_simplified) { affine_map_ = simplified_affine_map; } - is_simplified_ = true; return affine_map_was_simplified || constraints_were_simplified || rtvars_were_eliminated; } @@ -1639,7 +1630,6 @@ void IndexingMap::ResetToKnownEmpty() { } constraints_.clear(); is_known_empty_ = true; - is_simplified_ = true; } bool IndexingMap::VerifyVariableIntervals() { @@ -2124,8 +2114,7 @@ IndexingMap IndexingMap::ConvertSymbolsToDimensions() const { AffineMap canonical_map = affine_map_.replaceDimsAndSymbols({}, syms_replacements, num_vars, 0); IndexingMap new_indexing_map(canonical_map, new_dim_vars, /*range_vars=*/{}, - /*rt_vars=*/{}, new_constraints, - /*is_simplified=*/false); + /*rt_vars=*/{}, new_constraints); return new_indexing_map; } diff --git a/xla/service/gpu/model/indexing_map.h b/xla/service/gpu/model/indexing_map.h index 36780ddd1841e2..5751cb4c886d10 100644 --- a/xla/service/gpu/model/indexing_map.h +++ b/xla/service/gpu/model/indexing_map.h @@ -297,8 +297,7 @@ class IndexingMap { IndexingMap( mlir::AffineMap affine_map, std::vector dimensions, std::vector range_vars, std::vector rt_vars, - absl::Span const> constraints = {}, - bool is_simplified = false); + absl::Span const> constraints = {}); IndexingMap(mlir::AffineMap affine_map, std::vector dimensions, std::vector range_vars, std::vector rt_vars, @@ -314,8 +313,7 @@ class IndexingMap { static IndexingMap FromTensorSizes( mlir::AffineMap affine_map, absl::Span dim_upper_bounds, - absl::Span symbol_upper_bounds, - bool is_simplified = false); + absl::Span symbol_upper_bounds); // 
Returns true if the indexing map is valid. bool Verify(std::ostream& out) const; @@ -397,10 +395,6 @@ class IndexingMap { // satisfies both constraints. bool IsKnownEmpty() const { return is_known_empty_; } - // Returns true if the indexing map is simplified. - void SetIsSimplified(bool is_simplified) { is_simplified_ = is_simplified; } - bool IsSimplified() const { return is_simplified_; } - bool IsUndefined() const { return affine_map_ == mlir::AffineMap(); } // Removes unused symbols from the `affine_map_` and constraints. @@ -474,8 +468,6 @@ class IndexingMap { llvm::DenseMap constraints_; // Flag to indicate that the domain is empty. bool is_known_empty_ = false; - // Flag to indicate that the indexing map is simplified. - bool is_simplified_ = false; }; std::ostream& operator<<(std::ostream& out, const IndexingMap& indexing_map); bool operator==(const IndexingMap& lhs, const IndexingMap& rhs); diff --git a/xla/service/gpu/model/indexing_map_serialization.cc b/xla/service/gpu/model/indexing_map_serialization.cc index 4e72a4b56dd94f..3d6eb9bf1b1b23 100644 --- a/xla/service/gpu/model/indexing_map_serialization.cc +++ b/xla/service/gpu/model/indexing_map_serialization.cc @@ -293,9 +293,6 @@ Token Parser::GetNextTokenImpl() { if (spelling == "domain") { return Token{spelling, Token::Kind::kKeywordDomain}; } - if (spelling == "is_simplified") { - return Token{spelling, Token::Kind::kKeywordIsSimplified}; - } if (spelling == "in") { return Token{spelling, Token::Kind::kKeywordIn}; } @@ -599,7 +596,8 @@ std::optional ParseIndexingMap(llvm::StringRef input, if (!parser.ParseVarName(&var_name) || !parser.ConsumeToken(Token::Kind::kKeywordIn) || !parser.ParseInterval(&interval) || - !parser.ConsumeToken(Token::Kind::kComma)) { + (parser.GetCurrentToken().kind != Token::Kind::kEOF && + !parser.ConsumeToken(Token::Kind::kComma))) { llvm::errs() << "Failed to parse DimVar\n"; return std::nullopt; } @@ -617,7 +615,8 @@ std::optional ParseIndexingMap(llvm::StringRef input, if (!parser.ParseVarName(&var_name) || !parser.ConsumeToken(Token::Kind::kKeywordIn) || !parser.ParseInterval(&interval) || - !parser.ConsumeToken(Token::Kind::kComma)) { + (parser.GetCurrentToken().kind != Token::Kind::kEOF && + !parser.ConsumeToken(Token::Kind::kComma))) { llvm::errs() << "Failed to parse RangeVar\n"; return std::nullopt; } @@ -629,31 +628,20 @@ std::optional ParseIndexingMap(llvm::StringRef input, } // Parse constraints. SmallVector constraint_bounds; - while (!parser.ConsumeToken(Token::Kind::kKeywordIsSimplified)) { + while (!parser.ConsumeToken(Token::Kind::kEOF)) { std::string affine_expr_str; Interval interval; if (!parser.ParseAffineExprString(&affine_expr_str) || !parser.ConsumeToken(Token::Kind::kKeywordIn) || !parser.ParseInterval(&interval) || - !parser.ConsumeToken(Token::Kind::kComma)) { + (parser.GetCurrentToken().kind != Token::Kind::kEOF && + !parser.ConsumeToken(Token::Kind::kComma))) { llvm::errs() << "Failed to parse constraint\n"; return std::nullopt; } affine_expr_strs.push_back(affine_expr_str); constraint_bounds.push_back(interval); } - // Parse is_simplified. - bool is_simplified; - if (!parser.ConsumeToken(Token::Kind::kColon) || - !parser.ParseBool(&is_simplified)) { - llvm::errs() << "Failed to parse is_simplified\n"; - return std::nullopt; - } - // Check that the input is consumed. - if (!parser.ConsumeToken(Token::Kind::kEOF)) { - return std::nullopt; - } - // Parse affine expressions. 
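With the "is_simplified: <bool>" suffix gone, the serialized domain no longer has a fixed final token, so the parser hunks above make the comma after each dimension, range variable, and constraint optional when the next token is end-of-input. A minimal sketch of that separator rule using a toy token cursor (hypothetical names, not the actual XLA Parser/Token classes):

// Toy illustration of the rule used above: a comma is consumed if one
// follows an entry, but end-of-input is also accepted as a terminator now
// that no "is_simplified: ..." suffix can follow the last entry.
#include <cassert>
#include <string>
#include <vector>

struct Cursor {
  std::vector<std::string> tokens;
  size_t pos = 0;
  bool AtEof() const { return pos == tokens.size(); }
  bool Consume(const std::string& tok) {
    if (!AtEof() && tokens[pos] == tok) { ++pos; return true; }
    return false;
  }
};

// Returns false on a malformed list; each entry is a single token here.
bool ParseEntries(Cursor& c, std::vector<std::string>& out) {
  while (!c.AtEof()) {
    out.push_back(c.tokens[c.pos++]);
    // Comma is required between entries but optional before EOF.
    if (!c.AtEof() && !c.Consume(",")) return false;
  }
  return true;
}

int main() {
  Cursor a{{"d0 in [0, 3]", ",", "s0 in [0, 255]"}};       // no trailing comma
  Cursor b{{"d0 in [0, 3]", ",", "s0 in [0, 255]", ","}};  // trailing comma ok
  std::vector<std::string> ea, eb;
  assert(ParseEntries(a, ea) && ea.size() == 2);
  assert(ParseEntries(b, eb) && eb.size() == 2);
  return 0;
}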
SmallVector affine_exprs; if (!ParseAffineExprsWithMLIR(dim_var_names, symbol_var_names, @@ -674,9 +662,8 @@ std::optional ParseIndexingMap(llvm::StringRef input, } auto map = AffineMap::get(dim_vars.size(), range_vars.size(), affine_map_results, context); - return IndexingMap{ - map, std::move(dim_vars), std::move(range_vars), /*rt_vars=*/{}, - constraints, is_simplified}; + return IndexingMap{map, std::move(dim_vars), std::move(range_vars), + /*rt_vars=*/{}, constraints}; } std::string ToString(AffineExpr affine_expr) { @@ -782,18 +769,29 @@ std::string ToString(const IndexingMap& indexing_map, return ss.str(); } ss << ", domain: "; + int64_t remaining_vars_to_print = + dim_vars.size() + range_vars.size() + rt_vars.size(); for (const auto& [index, dim_var] : llvm::enumerate(dim_vars)) { - ss << dim_names[index] << " in " << dim_var.bounds << ", "; + ss << dim_names[index] << " in " << dim_var.bounds; + if (--remaining_vars_to_print > 0) { + ss << ", "; + } } for (const auto& [index, range_var] : llvm::enumerate(range_vars)) { - ss << symbol_names[index] << " in " << range_var.range << ", "; + ss << symbol_names[index] << " in " << range_var.range; + if (--remaining_vars_to_print > 0) { + ss << ", "; + } } int64_t num_range_vars = range_vars.size(); for (const auto& [index, rt_var] : llvm::enumerate(rt_vars)) { ss << GetSymbolName(num_range_vars + index, symbol_names) << " in " << rt_var.feasible_values << ", hlo: " << (rt_var.hlo == nullptr ? "NULL" : rt_var.hlo->ToString()) << ", " - << ToString(rt_var.map) << ", "; + << ToString(rt_var.map); + if (--remaining_vars_to_print > 0) { + ss << ", "; + } } std::vector expr_range_strings; const auto& constraints = indexing_map.GetConstraints(); @@ -803,10 +801,9 @@ std::string ToString(const IndexingMap& indexing_map, ToString(expr, dim_names, symbol_names), " in ", range.ToString())); } std::sort(expr_range_strings.begin(), expr_range_strings.end()); - for (const auto& expr_range_string : expr_range_strings) { - ss << expr_range_string << ", "; + if (!expr_range_strings.empty()) { + ss << ", " << absl::StrJoin(expr_range_strings, ", "); } - ss << "is_simplified: " << (indexing_map.IsSimplified() ? 
"true" : "false"); return ss.str(); } diff --git a/xla/service/gpu/model/indexing_map_serialization_test.cc b/xla/service/gpu/model/indexing_map_serialization_test.cc index 28a7f7b60b4ac8..98fff83cc277fd 100644 --- a/xla/service/gpu/model/indexing_map_serialization_test.cc +++ b/xla/service/gpu/model/indexing_map_serialization_test.cc @@ -48,8 +48,7 @@ TEST_F(IndexingMapSerializationTest, DimsOnly) { (d0, d1) -> (d0 mod 2 + d1), domain: d0 in [0, 3], - d1 in [-4, 4], - is_simplified: true + d1 in [-4, 4] )"); } @@ -58,8 +57,7 @@ TEST_F(IndexingMapSerializationTest, SymbolsOnly) { ()[s0, s1] -> (s0 floordiv s1), domain: s0 in [0, 3], - s1 in [0, 4], - is_simplified: true + s1 in [0, 4] )"); } @@ -71,8 +69,7 @@ TEST_F(IndexingMapSerializationTest, DimsAndSymbolsNoConstraints) { d1 in [0, 4], s0 in [0, 1], s1 in [0, 1], - s2 in [0, 3], - is_simplified: false + s2 in [0, 3] )"); } @@ -86,8 +83,7 @@ TEST_F(IndexingMapSerializationTest, DimsAndSymbolsAndConstraints) { s1 in [0, 1], s2 in [0, 3], d0 mod 4 in [0, 0], - d1 + s0 in [0, 45], - is_simplified: false + d1 + s0 in [0, 45] )"); } @@ -99,8 +95,7 @@ TEST_F(IndexingMapSerializationTest, AffineExprsWithParens) { d0 in [0, 9], d1 in [0, 19], s0 in [0, 29], - s1 in [0, 39], - is_simplified: false + s1 in [0, 39] )"); } @@ -116,8 +111,7 @@ TEST_F(IndexingMapSerializationTest, CustomNames) { reduced_dim in [0, 1], contracted_dim in [0, 3], th_x mod 4 in [0, 0], - bl_x + vector_elem in [0, 45], - is_simplified: false + bl_x + vector_elem in [0, 45] )"; auto indexing_map_golden = R"( (d0, d1)[s0, s1, s2] -> (s2, d0 + d1, s1, s0), @@ -128,8 +122,7 @@ TEST_F(IndexingMapSerializationTest, CustomNames) { s1 in [0, 1], s2 in [0, 3], d0 mod 4 in [0, 0], - d1 + s0 in [0, 45], - is_simplified: false + d1 + s0 in [0, 45] )"; auto indexing_map = ParseIndexingMap(indexing_map_str, &mlir_context_); ASSERT_TRUE(indexing_map.has_value()); diff --git a/xla/service/gpu/model/indexing_map_test.cc b/xla/service/gpu/model/indexing_map_test.cc index 8fd369b5b6e596..b7e45e141b6af6 100644 --- a/xla/service/gpu/model/indexing_map_test.cc +++ b/xla/service/gpu/model/indexing_map_test.cc @@ -136,8 +136,7 @@ TEST_F(IndexingMapTest, RTVar) { (d0, d1) -> (), rt_1 in [0, 7], hlo: NULL, - (d0, d1) -> (), - is_simplified: false + (d0, d1) -> () )")); } @@ -148,8 +147,7 @@ TEST_F(IndexingMapTest, Evaluation) { d0 in [0, 3], d1 in [0, 3], s0 in [0, 1], - s1 in [0, 1], - is_simplified: false + s1 in [0, 1] )"); auto results = indexing_map.Evaluate( mlir::getAffineConstantExprs({1, 2}, &mlir_context_), @@ -177,16 +175,14 @@ TEST_F(IndexingMapTest, Composition_Permutation) { d0 in [0, 3], d1 in [0, 3], s0 in [0, 1], - s1 in [0, 1], - is_simplified: false + s1 in [0, 1] )"); IndexingMap consumer = Parse(R"( (d0)[s0] -> (d0, s0), domain: d0 in [0, 3], - s0 in [0, 3], - is_simplified: false + s0 in [0, 3] )"); auto composed = ComposeIndexingMaps(consumer, producer); @@ -196,8 +192,7 @@ TEST_F(IndexingMapTest, Composition_Permutation) { d0 in [0, 3], s0 in [0, 1], s1 in [0, 1], - s2 in [0, 3], - is_simplified: false + s2 in [0, 3] )")); } @@ -208,16 +203,14 @@ TEST_F(IndexingMapTest, Composition_RestrictedInterval) { d0 in [0, 4], d1 in [0, 5], s0 in [0, 6], - s1 in [0, 1], - is_simplified: false + s1 in [0, 1] )"); IndexingMap consumer = Parse(R"( (d0)[s0] -> (d0, s0), domain: d0 in [0, 9], - s0 in [0, 7], - is_simplified: false + s0 in [0, 7] )"); auto composed = ComposeIndexingMaps(consumer, producer); @@ -227,8 +220,7 @@ TEST_F(IndexingMapTest, Composition_RestrictedInterval) { d0 in 
[0, 4], s0 in [0, 6], s1 in [0, 1], - s2 in [0, 5], - is_simplified: false + s2 in [0, 5] )")); } @@ -241,8 +233,7 @@ TEST_F(IndexingMapTest, Composition_ProducerAndConsumerHaveConstraints) { s0 in [0, 69], s1 in [0, 19], d0 mod 8 in [0, 0], - s0 mod 3 in [1, 1], - is_simplified: false + s0 mod 3 in [1, 1] )"); IndexingMap consumer = Parse(R"( @@ -251,8 +242,7 @@ TEST_F(IndexingMapTest, Composition_ProducerAndConsumerHaveConstraints) { d0 in [0, 9], s0 in [0, 7], d0 + s0 in [0, 20], - s0 mod 4 in [0, 0], - is_simplified: false + s0 mod 4 in [0, 0] )"); auto composed = ComposeIndexingMaps(consumer, producer); @@ -266,8 +256,7 @@ TEST_F(IndexingMapTest, Composition_ProducerAndConsumerHaveConstraints) { d0 + s2 in [0, 20], d0 mod 8 in [0, 0], s0 mod 3 in [1, 1], - s2 mod 4 in [0, 0], - is_simplified: false + s2 mod 4 in [0, 0] )")); EXPECT_TRUE(composed.Simplify()); EXPECT_THAT(composed, MatchIndexingMap(R"( @@ -279,8 +268,7 @@ TEST_F(IndexingMapTest, Composition_ProducerAndConsumerHaveConstraints) { s2 in [0, 4], d0 mod 8 in [0, 0], s0 mod 3 in [1, 1], - s2 mod 4 in [0, 0], - is_simplified: true + s2 mod 4 in [0, 0] )")); } @@ -319,8 +307,7 @@ TEST_F(IndexingMapTest, Composition_RTVar) { (d0, d1) -> (), rt_2 in [0, 226], hlo: NULL, - (d0, d1) -> (), - is_simplified: false + (d0, d1) -> () )")); } @@ -365,8 +352,7 @@ TEST_F(IndexingMapTest, Composition_OnlyRTVars) { hlo: NULL, (d0, d1) -> (), d0 + cs_0 * 2 in [0, 24], - d1 + cs_1 * 3 in [0, 15], - is_simplified: false + d1 + cs_1 * 3 in [0, 15] )")); } @@ -380,8 +366,7 @@ TEST_F(IndexingMapTest, RemoveUnusedVars_ConstraintUsesDim) { s0 in [0, 69], s1 in [0, 19], d0 + s0 in [1, 100], - s0 mod 3 in [0, 0], - is_simplified: false + s0 mod 3 in [0, 0] )"); indexing_map.RemoveUnusedVars(); EXPECT_THAT(indexing_map, MatchIndexingMap(R"( @@ -392,8 +377,7 @@ TEST_F(IndexingMapTest, RemoveUnusedVars_ConstraintUsesDim) { s0 in [0, 69], s1 in [0, 19], d0 + s0 in [1, 100], - s0 mod 3 in [0, 0], - is_simplified: false + s0 mod 3 in [0, 0] )")); } @@ -406,8 +390,7 @@ TEST_F(IndexingMapTest, RemoveUnusedVars_ConstraintUsesUnusedDim) { d1 in [0, 59], s0 in [0, 69], s1 in [0, 19], - d0 mod 3 in [0, 0], - is_simplified: false + d0 mod 3 in [0, 0] )"); indexing_map.RemoveUnusedVars(); EXPECT_THAT(indexing_map, MatchIndexingMap(R"( @@ -415,8 +398,7 @@ TEST_F(IndexingMapTest, RemoveUnusedVars_ConstraintUsesUnusedDim) { domain: d0 in [0, 59], s0 in [0, 69], - s1 in [0, 19], - is_simplified: false + s1 in [0, 19] )")); } @@ -429,8 +411,7 @@ TEST_F(IndexingMapTest, RemoveUnusedSymbols_ConstraintUsesOnlyUnusedSym) { d1 in [0, 59], s0 in [0, 69], s1 in [0, 19], - s0 mod 3 in [0, 0], - is_simplified: false + s0 mod 3 in [0, 0] )"); indexing_map.RemoveUnusedSymbols(); EXPECT_THAT(indexing_map, MatchIndexingMap(R"( @@ -438,8 +419,7 @@ TEST_F(IndexingMapTest, RemoveUnusedSymbols_ConstraintUsesOnlyUnusedSym) { domain: d0 in [0, 49], d1 in [0, 59], - s0 in [0, 19], - is_simplified: false + s0 in [0, 19] )")); } @@ -456,8 +436,7 @@ TEST_F(IndexingMapTest, RemoveUnusedVars_ConstraintsWithManyDims) { s1 in [0, 63], s2 in [0, 95], s0 * 4 + d1 + d3 in [24, 459], - s0 + s2 in [0, 512], - is_simplified: false + s0 + s2 in [0, 512] )"); // dimensions d0, d2, d4 and symbol s1 will be removed. 
auto unused_vars = indexing_map.RemoveUnusedVars(); @@ -469,8 +448,7 @@ TEST_F(IndexingMapTest, RemoveUnusedVars_ConstraintsWithManyDims) { s0 in [0, 31], s1 in [0, 95], d0 + s0 * 4 + d1 in [24, 459], - s0 + s1 in [0, 512], - is_simplified: false + s0 + s1 in [0, 512] )")); EXPECT_THAT(ConvertToSTL(unused_vars), ::testing::ElementsAreArray( @@ -486,8 +464,7 @@ TEST_F(IndexingMapTest, RemoveUnusedSymbols_ConstraintUsesSymbol) { s0 in [0, 69], s1 in [0, 19], s0 + s1 in [1, 100], - s0 mod 3 in [0, 0], - is_simplified: false + s0 mod 3 in [0, 0] )"); // This constraint cannot be removed, because it contains a "used symbol". indexing_map.RemoveUnusedSymbols(); @@ -499,8 +476,7 @@ TEST_F(IndexingMapTest, RemoveUnusedSymbols_ConstraintUsesSymbol) { s0 in [0, 69], s1 in [0, 19], s0 + s1 in [1, 100], - s0 mod 3 in [0, 0], - is_simplified: false + s0 mod 3 in [0, 0] )")); } @@ -512,8 +488,7 @@ TEST_F(IndexingMapTest, RemoveUnusedSymbols_ConstraintUsesOnlyUnusedSymbols) { d1 in [0, 59], s0 in [0, 69], s1 in [0, 19], - s0 mod 3 in [0, 0], - is_simplified: false + s0 mod 3 in [0, 0] )"); // This constraint can be removed, because it contains only the unused symbol. indexing_map.RemoveUnusedSymbols(); @@ -522,8 +497,7 @@ TEST_F(IndexingMapTest, RemoveUnusedSymbols_ConstraintUsesOnlyUnusedSymbols) { domain: d0 in [0, 49], d1 in [0, 59], - s0 in [0, 19], - is_simplified: false + s0 in [0, 19] )")); } @@ -532,14 +506,12 @@ TEST_F(IndexingMapTest, RemoveUnusedSymbols_ConstraintIsAConstantWithinRange) { (d0) -> (d0), domain: d0 in [0, 49], - 0 in [-10, 5], - is_simplified: false + 0 in [-10, 5] )"); EXPECT_THAT(indexing_map, MatchIndexingMap(R"( (d0) -> (d0), domain: - d0 in [0, 49], - is_simplified: false + d0 in [0, 49] )")); } @@ -547,8 +519,7 @@ TEST_F(IndexingMapTest, KnownEmpty_CreatingIndexingMapWithInfeasibleRange) { auto indexing_map = Parse(R"( (d0) -> (d0), domain: - d0 in [0, -2], - is_simplified: false + d0 in [0, -2] )"); EXPECT_THAT(indexing_map, MatchIndexingMap("KNOWN EMPTY")); } @@ -558,20 +529,15 @@ TEST_F(IndexingMapTest, KnownEmpty_AddingConstraintOutOfRange) { (d0) -> (d0), domain: d0 in [0, 49], - 0 in [10, 15], - is_simplified: false + 0 in [10, 15] )"); // Addition of this constraint makes the domain empty. EXPECT_THAT(indexing_map, MatchIndexingMap("KNOWN EMPTY")); } TEST_F(IndexingMapTest, KnownEmpty_Composition) { - auto indexing_map = Parse(R"( - (d0) -> (d0), domain: d0 in [0, 49], is_simplified: false - )"); - auto known_empty = Parse(R"( - (d0) -> (d0), domain: d0 in [0, -1], is_simplified: false - )"); + auto indexing_map = Parse("(d0) -> (d0), domain: d0 in [0, 49]"); + auto known_empty = Parse("(d0) -> (d0), domain: d0 in [0, -1]"); EXPECT_THAT(known_empty, MatchIndexingMap("KNOWN EMPTY")); EXPECT_THAT(indexing_map * known_empty, MatchIndexingMap("KNOWN EMPTY")); EXPECT_THAT(known_empty * indexing_map, MatchIndexingMap("KNOWN EMPTY")); @@ -588,8 +554,7 @@ TEST_F(IndexingMapTest, d1 in [0, 59], s0 in [0, 69], s1 in [0, 19], - s1 floordiv 20 in [2, 2], - is_simplified: false + s1 floordiv 20 in [2, 2] )"); EXPECT_TRUE(indexing_map.Simplify()); EXPECT_THAT(indexing_map, MatchIndexingMap("KNOWN EMPTY")); @@ -605,8 +570,7 @@ TEST_F(IndexingMapTest, RemoveUnusedSymbols_ConstraintsWithManySymbols) { s2 in [0, 2], s3 in [0, 3], s4 in [0, 4], - d0 * 4 + s1 + s3 in [24, 459], - is_simplified: false + d0 * 4 + s1 + s3 in [24, 459] )"); indexing_map.RemoveUnusedSymbols(); // Symbols s0, s2, s4 will be removed and s1 and s3 will become s0 and s1. 
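The comment above spells out the renumbering that RemoveUnusedSymbols performs: surviving symbols are compacted toward the front, so s1 and s3 become s0 and s1 in the expected string that follows. A standalone sketch of that compaction (plain C++, not the real IndexingMap implementation):

#include <cassert>
#include <vector>

// For each old symbol id, returns its new id after unused symbols are
// dropped, or -1 if the symbol itself was removed.
std::vector<int> CompactSymbolIds(const std::vector<bool>& used) {
  std::vector<int> new_id(used.size(), -1);
  int next = 0;
  for (size_t i = 0; i < used.size(); ++i) {
    if (used[i]) new_id[i] = next++;
  }
  return new_id;
}

int main() {
  // s0, s2, s4 are unused; s1 -> s0 and s3 -> s1, matching the test above.
  std::vector<int> ids = CompactSymbolIds({false, true, false, true, false});
  assert(ids[1] == 0 && ids[3] == 1);
  return 0;
}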
@@ -616,8 +580,7 @@ TEST_F(IndexingMapTest, RemoveUnusedSymbols_ConstraintsWithManySymbols) { d0 in [0, 31], s0 in [0, 1], s1 in [0, 3], - d0 * 4 + s0 + s1 in [24, 459], - is_simplified: false + d0 * 4 + s0 + s1 in [24, 459] )")); } @@ -644,8 +607,7 @@ TEST_F(IndexingMapTest, RemoveUnusedSymbols_ConstraintsWithRTVars) { s1 in [0, 3], hlo: NULL, (d0) -> (), - d0 * 4 + s0 + s1 in [24, 459], - is_simplified: false + d0 * 4 + s0 + s1 in [24, 459] )")); }; @@ -669,8 +631,7 @@ TEST_F(IndexingMapTest, ConvertSymbolsToDimensions) { d2 in [0, 1], d3 in [0, 3], d4 in [0, 4], - d0 * 4 + d1 + d3 * 2 in [24, 459], - is_simplified: false + d0 * 4 + d1 + d3 * 2 in [24, 459] )")); } @@ -679,16 +640,14 @@ TEST_F(IndexingMapTest, ConstraintIntervalSimplification_Sum) { (d0) -> (d0), domain: d0 in [0, 99], - d0 mod 8 + 5 in [50, 54], - is_simplified: false + d0 mod 8 + 5 in [50, 54] )"); EXPECT_TRUE(indexing_map.Simplify()); EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( (d0) -> (d0), domain: d0 in [0, 99], - d0 mod 8 in [45, 49], - is_simplified: true + d0 mod 8 in [45, 49] )")); } @@ -700,8 +659,7 @@ TEST_F(IndexingMapTest, d0 in [0, 1999], s0 in [0, 1], s1 in [0, 2], - d0 * 6 + s0 * 3 + s1 in [0, 599], - is_simplified: false + d0 * 6 + s0 * 3 + s1 in [0, 599] )"); EXPECT_TRUE(indexing_map.Simplify()); EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( @@ -709,8 +667,7 @@ TEST_F(IndexingMapTest, domain: d0 in [0, 99], s0 in [0, 1], - s1 in [0, 2], - is_simplified: true + s1 in [0, 2] )")); } @@ -722,8 +679,7 @@ TEST_F(IndexingMapTest, d0 in [0, 1999], s0 in [0, 1], s1 in [0, 2], - d0 * 6 + s0 * 3 + s1 in [0, 598], - is_simplified: false + d0 * 6 + s0 * 3 + s1 in [0, 598] )"); EXPECT_FALSE(indexing_map.Simplify()); } @@ -734,16 +690,14 @@ TEST_F(IndexingMapTest, ConstraintIntervalSimplification_Sum_GcdGreaterOne) { domain: d0 in [0, 1999], s0 in [0, 1], - d0 * 6 + s0 * 3 in [0, 599], - is_simplified: false + d0 * 6 + s0 * 3 in [0, 599] )"); EXPECT_TRUE(indexing_map.Simplify()); EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( (d0)[s0] -> (d0 * 6 + s0 * 3), domain: d0 in [0, 99], - s0 in [0, 1], - is_simplified: true + s0 in [0, 1] )")); } @@ -753,15 +707,13 @@ TEST_F(IndexingMapTest, (d0) -> (d0), domain: d0 in [0, 99], - d0 floordiv 8 in [5, 11], - is_simplified: false + d0 floordiv 8 in [5, 11] )"); EXPECT_TRUE(indexing_map.Simplify()); EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( (d0) -> (d0), domain: - d0 in [40, 95], - is_simplified: true + d0 in [40, 95] )")); } @@ -772,16 +724,14 @@ TEST_F(IndexingMapTest, domain: d0 in [0, 99], s0 in [-99, 99], - s0 floordiv 3 in [-11, -5], - is_simplified: false + s0 floordiv 3 in [-11, -5] )"); EXPECT_TRUE(indexing_map.Simplify()); EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( (d0)[s0] -> (d0), domain: d0 in [0, 99], - s0 in [-33, -13], - is_simplified: true + s0 in [-33, -13] )")); } @@ -792,16 +742,14 @@ TEST_F(IndexingMapTest, domain: d0 in [0, 99], s0 in [-99, 99], - s0 floordiv -3 in [-11, -5], - is_simplified: false + s0 floordiv -3 in [-11, -5] )"); EXPECT_TRUE(indexing_map.Simplify()); EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( (d0)[s0] -> (d0), domain: d0 in [0, 99], - s0 in [15, 35], - is_simplified: true + s0 in [15, 35] )")); } @@ -811,15 +759,13 @@ TEST_F(IndexingMapTest, (d0) -> (d0), domain: d0 in [0, 99], - d0 * 8 in [14, 33], - is_simplified: false + d0 * 8 in [14, 33] )"); EXPECT_TRUE(indexing_map.Simplify()); EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( 
(d0) -> (d0), domain: - d0 in [2, 4], - is_simplified: true + d0 in [2, 4] )")); } @@ -830,16 +776,14 @@ TEST_F(IndexingMapTest, domain: d0 in [0, 99], s0 in [-99, 99], - s0 * 3 in [-11, -5], - is_simplified: false + s0 * 3 in [-11, -5] )"); EXPECT_TRUE(indexing_map.Simplify()); EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( (d0)[s0] -> (d0), domain: d0 in [0, 99], - s0 in [-3, -2], - is_simplified: true + s0 in [-3, -2] )")); } @@ -850,16 +794,14 @@ TEST_F(IndexingMapTest, domain: d0 in [0, 99], s0 in [-99, 99], - s0 * -3 in [-11, -5], - is_simplified: false + s0 * -3 in [-11, -5] )"); EXPECT_TRUE(indexing_map.Simplify()); EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( (d0)[s0] -> (d0), domain: d0 in [0, 99], - s0 in [2, 3], - is_simplified: true + s0 in [2, 3] )")); } @@ -873,8 +815,7 @@ TEST_F(IndexingMapTest, ConstraintMerge_Mod) { d0 mod 3 in [0, 0], s0 mod 2 in [0, 0], s0 mod 3 in [0, 0], - s1 mod 5 in [1, 1], - is_simplified: false + s1 mod 5 in [1, 1] )"); EXPECT_TRUE(indexing_map.Simplify()); EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( @@ -885,8 +826,7 @@ TEST_F(IndexingMapTest, ConstraintMerge_Mod) { s1 in [1, 6], d0 mod 3 in [0, 0], s0 mod 6 in [0, 0], - s1 mod 5 in [1, 1], - is_simplified: true + s1 mod 5 in [1, 1] )")); } @@ -894,15 +834,13 @@ TEST_F(IndexingMapTest, AffineMapSimplification_ConstantDims) { auto indexing_map = Parse(R"( (d0) -> (d0), domain: - d0 in [5, 5], - is_simplified: false + d0 in [5, 5] )"); EXPECT_TRUE(indexing_map.Simplify()); EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( (d0) -> (5), domain: - d0 in [5, 5], - is_simplified: true + d0 in [5, 5] )")); } @@ -916,8 +854,7 @@ TEST_F(IndexingMapTest, AffineMapSimplification_SumOrderRegression) { d0 in [0, 9], d1 in [0, 19], s0 in [0, 29], - s1 in [0, 39], - is_simplified: false + s1 in [0, 39] )"); EXPECT_TRUE(indexing_map.Simplify()); EXPECT_FALSE(indexing_map.Simplify()); @@ -930,8 +867,7 @@ TEST_F(IndexingMapTest, AffineMapSimplification_SumOrderRegression2) { (d0)[s0] -> ((((s0 + d0) + d0) floordiv 2)), domain: d0 in [0, 9], - s0 in [0, 19], - is_simplified: false + s0 in [0, 19] )"); EXPECT_TRUE(indexing_map.Simplify()); EXPECT_FALSE(indexing_map.Simplify()); @@ -942,16 +878,14 @@ TEST_F(IndexingMapTest, AffineMapSimplification_FloorDivRegression) { (d0, d1) -> (((d0 floordiv 3) * 3 + d1 floordiv 2) floordiv 6), domain: d0 in [0, 11], - d1 in [0, 5], - is_simplified: false + d1 in [0, 5] )"); EXPECT_TRUE(indexing_map.Simplify()); EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( (d0, d1) -> (d0 floordiv 6), domain: d0 in [0, 11], - d1 in [0, 5], - is_simplified: true + d1 in [0, 5] )")); } @@ -959,15 +893,13 @@ TEST_F(IndexingMapTest, AffineMapSimplification_ModIsSub) { auto indexing_map = Parse(R"( (d0) -> (d0 mod 42), domain: - d0 in [53, 71], - is_simplified: false + d0 in [53, 71] )"); EXPECT_TRUE(indexing_map.Simplify()); EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( (d0) -> (d0 - 42), domain: - d0 in [53, 71], - is_simplified: true + d0 in [53, 71] )")); } @@ -975,24 +907,20 @@ TEST_F(IndexingMapTest, AffineMapSimplification_ModIsAdd) { auto indexing_map = Parse(R"( (d0) -> (d0 mod 5), domain: - d0 in [-5, -1], - is_simplified: false + d0 in [-5, -1] )"); EXPECT_TRUE(indexing_map.Simplify()); EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( (d0) -> (d0 + 5), domain: - d0 in [-5, -1], - is_simplified: true + d0 in [-5, -1] )")); } TEST_F(IndexingMapTest, AffineMapSimplification_ModIsNotAdd) { - auto 
indexing_map1 = - Parse("(d0) -> (d0 mod 5), domain: d0 in [-4, 0], is_simplified: false"); + auto indexing_map1 = Parse("(d0) -> (d0 mod 5), domain: d0 in [-4, 0]"); EXPECT_FALSE(indexing_map1.Simplify()); - auto indexing_map2 = - Parse("(d0) -> (d0 mod 5), domain: d0 in [-6, -1], is_simplified: false"); + auto indexing_map2 = Parse("(d0) -> (d0 mod 5), domain: d0 in [-6, -1]"); EXPECT_FALSE(indexing_map2.Simplify()); } @@ -1001,16 +929,14 @@ TEST_F(IndexingMapTest, AffineMapSimplification_SubIsMod) { (d0)[s0] -> (d0 - (s0 floordiv 3) * 3 + s0), domain: d0 in [0, 1], - s0 in [0, 3], - is_simplified: false + s0 in [0, 3] )"); EXPECT_TRUE(indexing_map.Simplify()); EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( (d0)[s0] -> (d0 + s0 mod 3), domain: d0 in [0, 1], - s0 in [0, 3], - is_simplified: true + s0 in [0, 3] )")); } @@ -1019,16 +945,14 @@ TEST_F(IndexingMapTest, AffineMapSimplification_SubIsModMultiplied) { (d0)[s0] -> (d0 - (s0 floordiv 3) * 12 + s0 * 7), domain: d0 in [0, 1], - s0 in [0, 3], - is_simplified: false + s0 in [0, 3] )"); EXPECT_TRUE(indexing_map.Simplify()); EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( (d0)[s0] -> (d0 + (s0 mod 3) * 4 + s0 * 3), domain: d0 in [0, 1], - s0 in [0, 3], - is_simplified: true + s0 in [0, 3] )")); } @@ -1037,16 +961,14 @@ TEST_F(IndexingMapTest, AffineMapSimplification_SubIsModSum) { (d0)[s0] -> (1 + d0 - ((s0 + 1) floordiv 3) * 3 + s0), domain: d0 in [0, 1], - s0 in [0, 3], - is_simplified: false + s0 in [0, 3] )"); EXPECT_TRUE(indexing_map.Simplify()); EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( (d0)[s0] -> (d0 + (s0 + 1) mod 3), domain: d0 in [0, 1], - s0 in [0, 3], - is_simplified: true + s0 in [0, 3] )")); } @@ -1056,16 +978,14 @@ TEST_F(IndexingMapTest, (d0, d1) -> (d0 + d1 floordiv 16, d1 mod 16), domain: d0 in [0, 7], - d1 in [0, 15], - is_simplified: false + d1 in [0, 15] )"); EXPECT_TRUE(indexing_map.Simplify()); EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( (d0, d1) -> (d0, d1), domain: d0 in [0, 7], - d1 in [0, 15], - is_simplified: true + d1 in [0, 15] )")); } @@ -1077,8 +997,7 @@ TEST_F(IndexingMapTest, AffineMapSimplification_DivsAndModsWithMultipliers) { domain: d0 in [0, 8], d1 in [0, 8], - d2 in [0, 8], - is_simplified: false + d2 in [0, 8] )"); EXPECT_TRUE(indexing_map.Simplify()); EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( @@ -1086,8 +1005,7 @@ TEST_F(IndexingMapTest, AffineMapSimplification_DivsAndModsWithMultipliers) { domain: d0 in [0, 8], d1 in [0, 8], - d2 in [0, 8], - is_simplified: true + d2 in [0, 8] )")); } @@ -1099,8 +1017,7 @@ TEST_F(IndexingMapTest, domain: d0 in [0, 9], d1 in [0, 9], - d2 in [0, 9], - is_simplified: false + d2 in [0, 9] )"); EXPECT_TRUE(indexing_map.Simplify()); EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( @@ -1109,8 +1026,7 @@ TEST_F(IndexingMapTest, domain: d0 in [0, 9], d1 in [0, 9], - d2 in [0, 9], - is_simplified: true + d2 in [0, 9] )")); } @@ -1120,16 +1036,14 @@ TEST_F(IndexingMapTest, AffineMapSimplification_DivsAndModsWithReverse) { d0 * 11 + d1 + ((d0 * -11 - d1 + 109) floordiv 11) * 11 - 99), domain: d0 in [0, 7], - d1 in [0, 8], - is_simplified: false + d1 in [0, 8] )"); EXPECT_TRUE(indexing_map.Simplify()); EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( (d0, d1) -> (d0, d1), domain: d0 in [0, 7], - d1 in [0, 8], - is_simplified: true + d1 in [0, 8] )")); } @@ -1137,15 +1051,13 @@ TEST_F(IndexingMapTest, AffineMapSimplification_SimplifyReshape) { auto indexing_map = Parse(R"( 
()[s0] -> ((s0 * 128) mod 715 + ((s0 * 128) floordiv 715) * 715), domain: - s0 in [0, 127], - is_simplified: false + s0 in [0, 127] )"); EXPECT_TRUE(indexing_map.Simplify()); EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( ()[s0] -> (s0 * 128), domain: - s0 in [0, 127], - is_simplified: true + s0 in [0, 127] )")); } @@ -1154,8 +1066,7 @@ TEST_F(IndexingMapTest, AffineMapSimplification_SimplifyReshape2) { (d0, d1) -> ((d0 mod 8) * 128 + d1 + (d0 floordiv 8) * 1024), domain: d0 in [0, 1023], - d1 in [0, 127], - is_simplified: false + d1 in [0, 127] )"); ; EXPECT_TRUE(indexing_map.Simplify()); @@ -1163,8 +1074,7 @@ TEST_F(IndexingMapTest, AffineMapSimplification_SimplifyReshape2) { (d0, d1) -> (d0 * 128 + d1), domain: d0 in [0, 1023], - d1 in [0, 127], - is_simplified: true + d1 in [0, 127] )")); } @@ -1174,16 +1084,14 @@ TEST_F(IndexingMapTest, AffineMapSimplification_SimplifyReshape3) { + ((d1 * 128 + d0) floordiv 192) * 768), domain: d0 in [0, 127], - d1 in [0, 3071], - is_simplified: false + d1 in [0, 3071] )"); EXPECT_TRUE(indexing_map.Simplify()); EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( (d0, d1) -> (d0 * 4 + d1 * 512), domain: d0 in [0, 127], - d1 in [0, 3071], - is_simplified: true + d1 in [0, 3071] )")); } @@ -1192,15 +1100,13 @@ TEST_F(IndexingMapTest, auto indexing_map = Parse(R"( (d0) -> ((-d0) mod 2), domain: - d0 in [0, 127], - is_simplified: false + d0 in [0, 127] )"); EXPECT_FALSE(indexing_map.Simplify()); EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( (d0) -> ((-d0) mod 2), domain: - d0 in [0, 127], - is_simplified: true + d0 in [0, 127] )")); } @@ -1215,16 +1121,14 @@ TEST_F(IndexingMapTest, AffineMapSimplification_SimplifyBitcastAndBack) { + ((d0 * 2 + d1 floordiv 64) mod 3) * 256 + (d1 mod 64) * 4), domain: d0 in [0, 3071], - d1 in [0, 127], - is_simplified: false + d1 in [0, 127] )"); EXPECT_TRUE(indexing_map.Simplify()); EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( (d0, d1) -> (d0 * 512 + d1 * 4), domain: d0 in [0, 3071], - d1 in [0, 127], - is_simplified: true + d1 in [0, 127] )")); } @@ -1233,15 +1137,13 @@ TEST_F(IndexingMapTest, AffineMapSimplification_SimplifyReshape_Regression) { auto indexing_map = Parse(R"( ()[s0] -> ((s0 * 128) mod 715 + ((s0 * 64) floordiv 715) * 715), domain: - s0 in [0, 127], - is_simplified: false + s0 in [0, 127] )"); EXPECT_TRUE(indexing_map.Simplify()); EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( ()[s0] -> (((s0 * 64) floordiv 715) * 715 + (s0 * 128) mod 715), domain: - s0 in [0, 127], - is_simplified: true + s0 in [0, 127] )")); } @@ -1249,15 +1151,13 @@ TEST_F(IndexingMapTest, AffineMapSimplification_DivsInSequence) { auto indexing_map = Parse(R"( ()[s0] -> (s0 - ((s0 floordiv 2) floordiv 7) * 14 + (s0 floordiv 14) * 14), domain: - s0 in [0, 1233], - is_simplified: false + s0 in [0, 1233] )"); EXPECT_TRUE(indexing_map.Simplify()); EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( ()[s0] -> (s0), domain: - s0 in [0, 1233], - is_simplified: true + s0 in [0, 1233] )")); } @@ -1266,16 +1166,14 @@ TEST_F(IndexingMapTest, AffineMapSimplification_DivDiv) { ()[s0, s1] -> ((s0 * 2 + s1 floordiv 64) floordiv 3), domain: s0 in [0, 1233], - s1 in [0, 127], - is_simplified: false + s1 in [0, 127] )"); EXPECT_TRUE(indexing_map.Simplify()); EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( ()[s0, s1] -> ((s0 * 128 + s1) floordiv 192), domain: s0 in [0, 1233], - s1 in [0, 127], - is_simplified: true + s1 in [0, 127] )")); } @@ -1283,15 +1181,13 @@ 
TEST_F(IndexingMapTest, AffineMapSimplification_DivSumConstant) { auto indexing_map = Parse(R"( ()[s0] -> ((s0 * 6 + 9) floordiv 18), domain: - s0 in [0, 1233], - is_simplified: false + s0 in [0, 1233] )"); EXPECT_TRUE(indexing_map.Simplify()); EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( ()[s0] -> ((s0 * 2 + 3) floordiv 6), domain: - s0 in [0, 1233], - is_simplified: true + s0 in [0, 1233] )")); } @@ -1300,8 +1196,7 @@ TEST_F(IndexingMapTest, AffineMapSimplification_DivSumDiv) { ()[s0, s1] -> ((s0 floordiv 3 + s1 floordiv 3) floordiv 6), domain: s0 in [0, 1233], - s1 in [0, 127], - is_simplified: false + s1 in [0, 127] )"); // The rewrite tested in AffineMapSimplification_DivDiv must not trigger here. EXPECT_FALSE(indexing_map.Simplify()); @@ -1314,8 +1209,7 @@ TEST_F(IndexingMapTest, AffineMapSimplification_NegativeDiv) { auto indexing_map = Parse(R"( ()[s0] -> ((s0 floordiv 2) floordiv -7), domain: - s0 in [0, 1233], - is_simplified: false + s0 in [0, 1233] )"); EXPECT_FALSE(indexing_map.Simplify()); } @@ -1327,8 +1221,7 @@ TEST_F(IndexingMapTest, AffineMapSimplification_ExtractFromMod) { s0 in [0, 871], s1 in [0, 3], s2 in [0, 127], - s3 in [0, 895], - is_simplified: false + s3 in [0, 895] )"); EXPECT_TRUE(indexing_map.Simplify()); EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( @@ -1339,8 +1232,7 @@ TEST_F(IndexingMapTest, AffineMapSimplification_ExtractFromMod) { s0 in [0, 871], s1 in [0, 3], s2 in [0, 127], - s3 in [0, 895], - is_simplified: true + s3 in [0, 895] )")); } @@ -1351,8 +1243,7 @@ TEST_F(IndexingMapTest, floordiv 4), domain: s0 in [0, 1], - s1 in [0, 127], - is_simplified: false + s1 in [0, 127] )"); EXPECT_TRUE(indexing_map.Simplify()); EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( @@ -1361,8 +1252,7 @@ TEST_F(IndexingMapTest, ), domain: s0 in [0, 1], - s1 in [0, 127], - is_simplified: true + s1 in [0, 127] )")); } @@ -1374,8 +1264,7 @@ TEST_F(IndexingMapTest, RescaleSymbols_Simple) { s0 in [0, 6], s1 in [0, 1], s2 in [0, 5], - s0 mod 6 in [0, 0], - is_simplified: false + s0 mod 6 in [0, 0] )"); EXPECT_TRUE(indexing_map.RescaleSymbols()); EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( @@ -1384,8 +1273,7 @@ TEST_F(IndexingMapTest, RescaleSymbols_Simple) { d0 in [0, 3], s0 in [0, 1], s1 in [0, 1], - s2 in [0, 5], - is_simplified: false + s2 in [0, 5] )")); } @@ -1397,8 +1285,7 @@ TEST_F(IndexingMapTest, RescaleSymbols_WithShift) { s0 in [0, 41], s1 in [0, 1], s2 in [0, 5], - s0 mod 6 in [3, 3], - is_simplified: false + s0 mod 6 in [3, 3] )"); // [BEFORE] Allowed values for s0: 3, 9, 15, ..., 39 = (6 * 6 + 3) // [AFTER] Allowed values for s0: 0, 1, 2, ..., 6 @@ -1409,8 +1296,7 @@ TEST_F(IndexingMapTest, RescaleSymbols_WithShift) { d0 in [0, 3], s0 in [0, 6], s1 in [0, 1], - s2 in [0, 5], - is_simplified: false + s2 in [0, 5] )")); } @@ -1423,8 +1309,7 @@ TEST_F(IndexingMapTest, RescaleSymbols_TwoModConstraints) { s1 in [0, 1], s2 in [0, 5], s0 mod 2 in [0, 0], - s0 mod 3 in [0, 0], - is_simplified: false + s0 mod 3 in [0, 0] )"); EXPECT_TRUE(indexing_map.RescaleSymbols()); EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( @@ -1433,8 +1318,7 @@ TEST_F(IndexingMapTest, RescaleSymbols_TwoModConstraints) { d0 in [0, 3], s0 in [0, 1], s1 in [0, 1], - s2 in [0, 5], - is_simplified: false + s2 in [0, 5] )")); } @@ -1447,8 +1331,7 @@ TEST_F(IndexingMapTest, RescaleSymbols_RescaledSymbolInOtherNonModConstraint) { s1 in [0, 1], s2 in [0, 5], s0 * s2 in [0, 28], - s0 mod 6 in [3, 3], - is_simplified: false + s0 mod 6 in 
[3, 3] )"); EXPECT_TRUE(indexing_map.RescaleSymbols()); EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( @@ -1458,8 +1341,7 @@ TEST_F(IndexingMapTest, RescaleSymbols_RescaledSymbolInOtherNonModConstraint) { s0 in [0, 1], s1 in [0, 1], s2 in [0, 5], - (s0 * 6 + 3) * s2 in [0, 28], - is_simplified: false + (s0 * 6 + 3) * s2 in [0, 28] )")); } @@ -1473,8 +1355,7 @@ TEST_F(IndexingMapTest, s1 in [0, 1], s2 in [0, 5], s0 mod 6 in [3, 3], - s0 mod 7 in [5, 5], - is_simplified: false + s0 mod 7 in [5, 5] )"); EXPECT_TRUE(indexing_map.RescaleSymbols()); @@ -1510,8 +1391,7 @@ TEST_F(IndexingMapTest, RescaleSymbolsKeepsHashmapConsistent) { s1 in [0, 1], s2 in [0, 5], s0 mod 6 in [0, 0], - s0 * s1 in [0, 100], - is_simplified: false + s0 * s1 in [0, 100] )"); EXPECT_TRUE(indexing_map.RescaleSymbols()); @@ -1528,8 +1408,7 @@ TEST_F(IndexingMapTest, RangeEvaluatorTest) { d0 in [0, 9], d1 in [-10, -1], d2 in [-1, 2], - d3 in [0, 0], - is_simplified: false + d3 in [0, 0] )"); RangeEvaluator range_evaluator(indexing_map, &mlir_context_); mlir::AffineExpr d0, d1, d2, d3; @@ -1784,8 +1663,7 @@ TEST_F(IndexingMapTest, ReplaceConstantRTVars_Iota) { EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( (d0) -> (d0, d0), domain: - d0 in [0, 255], - is_simplified: true + d0 in [0, 255] )")); } @@ -1815,8 +1693,7 @@ TEST_F(IndexingMapTest, ReplaceConstantRTVars_IotaAsConstant) { EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( (d0) -> (d0, 7), domain: - d0 in [0, 255], - is_simplified: true + d0 in [0, 255] )")); } @@ -1849,8 +1726,7 @@ TEST_F(IndexingMapTest, ReplaceConstantRTVars_ConstraintsGetUpdated) { (d0) -> (d0, d0), domain: d0 in [0, 254], - d0 mod 2 in [0, 0], - is_simplified: true + d0 mod 2 in [0, 0] )")); } @@ -1883,8 +1759,7 @@ TEST_F(IndexingMapTest, ReplaceConstantRTVars_Broadcast) { EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( (d0) -> (d0, 11), domain: - d0 in [0, 31], - is_simplified: true + d0 in [0, 31] )")); } @@ -1926,8 +1801,7 @@ TEST_F(IndexingMapTest, ReplaceConstantRTVars_ChainedNoncomputeOps) { EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( (d0) -> (d0, (d0 floordiv 12) * -4 + 8), domain: - d0 in [0, 35], - is_simplified: true + d0 in [0, 35] )")); } @@ -1963,8 +1837,7 @@ TEST_F(IndexingMapTest, ReplaceConstantRTVars_PartialRTVarRemoval) { d0 in [0, 23], s0 in [0, 512], hlo: %constant = s64[12]{0} constant({...}), - (d0) -> (d0 floordiv 2), - is_simplified: true + (d0) -> (d0 floordiv 2) )")); } @@ -1999,8 +1872,7 @@ TEST_F(IndexingMapTest, ReplaceConstantRTVars_Add) { EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( (d0) -> (d0, d0 * 2 + 42), domain: - d0 in [0, 11], - is_simplified: true + d0 in [0, 11] )")); } @@ -2040,8 +1912,7 @@ TEST_F(IndexingMapTest, ReplaceConstantRTVars_Multiply) { EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( (d0) -> (d0, (-d0 + 11) * d0), domain: - d0 in [0, 11], - is_simplified: true + d0 in [0, 11] )")); } @@ -2080,8 +1951,7 @@ TEST_F(IndexingMapTest, ReplaceConstantRTVars_PartiallyOptimizableAdd) { d0 in [0, 11], s0 in [0, 11], hlo: %constant = s64[12]{0} constant({...}), - (d0) -> (d0), - is_simplified: true + (d0) -> (d0) )")); } @@ -2173,8 +2043,7 @@ TEST_F(IndexingMapTest, IndexingMapSupportsAbslHashAndEqAndNe) { d0 in [0, 49], d1 in [0, 59], s0 in [0, 69], - s1 in [0, 79], - is_simplified: false + s1 in [0, 79] )"), Parse(R"( (d0, d1)[s0, s1] -> (d1 * 2, d0, s1, s0), @@ -2182,8 +2051,7 @@ TEST_F(IndexingMapTest, IndexingMapSupportsAbslHashAndEqAndNe) { d0 in [0, 49], d1 
in [0, 59], s0 in [0, 69], - s1 in [0, 79], - is_simplified: false + s1 in [0, 79] )"), Parse(R"( (d0, d1)[s0, s1] -> (d1, d0, s1, s0), @@ -2191,8 +2059,7 @@ TEST_F(IndexingMapTest, IndexingMapSupportsAbslHashAndEqAndNe) { d0 in [0, 50], d1 in [0, 59], s0 in [0, 69], - s1 in [0, 79], - is_simplified: false + s1 in [0, 79] )"), Parse(R"( (d0, d1)[s0, s1] -> (d1, d0, s1, s0), @@ -2200,8 +2067,7 @@ TEST_F(IndexingMapTest, IndexingMapSupportsAbslHashAndEqAndNe) { d0 in [0, 49], d1 in [0, 59], s0 in [0, 69], - s1 in [0, 79], - is_simplified: false + s1 in [0, 79] )"), Parse(R"( (d0, d1)[s0, s1] -> (d1, d0, s1, s0), @@ -2211,8 +2077,7 @@ TEST_F(IndexingMapTest, IndexingMapSupportsAbslHashAndEqAndNe) { s0 in [0, 69], s1 in [0, 79], d0 mod 8 in [0, 0], - d0 mod 16 in [0, 0], - is_simplified: false + d0 mod 16 in [0, 0] )"), Parse(R"( (d0, d1)[s0, s1] -> (d1, d0, s1, s0), @@ -2222,8 +2087,7 @@ TEST_F(IndexingMapTest, IndexingMapSupportsAbslHashAndEqAndNe) { s0 in [0, 69], s1 in [0, 79], d0 mod 8 in [0, 0], - d0 mod 32 in [0, 0], - is_simplified: false + d0 mod 32 in [0, 0] )"), IndexingMap( ParseAffineMap("(d0)[s0, s1, s2, s3, s4] -> (d0 * 4 + s1 + s3 - 42)", diff --git a/xla/service/gpu/model/symbolic_tile_analysis_test.cc b/xla/service/gpu/model/symbolic_tile_analysis_test.cc index db62388f89f099..d9607223fae319 100644 --- a/xla/service/gpu/model/symbolic_tile_analysis_test.cc +++ b/xla/service/gpu/model/symbolic_tile_analysis_test.cc @@ -169,8 +169,7 @@ ENTRY main { (d0, d1) -> (d0, d1 * 10), domain: d0 in [0, 1], - d1 in [0, 9], - is_simplified: true + d1 in [0, 9] )")); auto p0_from_subtract0 = root->operand(0); @@ -183,8 +182,7 @@ ENTRY main { (d0, d1) -> (d0, d1 * 10), domain: d0 in [0, 1], - d1 in [0, 9], - is_simplified: true + d1 in [0, 9] )")); EXPECT_THAT(*p0_from_subtract1, MatchTiledHloInstruction( @@ -194,8 +192,7 @@ ENTRY main { (d0, d1) -> (d0, 0), domain: d0 in [0, 1], - d1 in [0, 9], - is_simplified: true + d1 in [0, 9] )")); } @@ -287,8 +284,7 @@ ENTRY main { (d0, d1) -> (d0, 0), domain: d0 in [0, 1], - d1 in [0, 0], - is_simplified: true + d1 in [0, 0] )")); } @@ -322,8 +318,7 @@ ENTRY main { domain: d0 in [0, 1], d1 in [0, 1], - d2 in [0, 7], - is_simplified: true + d2 in [0, 7] )")); EXPECT_THAT(*root->operand(0), @@ -334,8 +329,7 @@ ENTRY main { domain: d0 in [0, 1], d1 in [0, 1], - d2 in [0, 7], - is_simplified: true + d2 in [0, 7] )")); } @@ -372,8 +366,7 @@ ENTRY main { (d0, d1) -> (d0 * 2, d1 * 2), domain: d0 in [0, 1], - d1 in [0, 3], - is_simplified: true + d1 in [0, 3] )")); EXPECT_THAT(*p0_from_slice0, @@ -383,8 +376,7 @@ ENTRY main { (d0, d1) -> (d0 * 2, d1 * 2 + 2), domain: d0 in [0, 1], - d1 in [0, 3], - is_simplified: true + d1 in [0, 3] )")); EXPECT_THAT(*p0_from_slice1, @@ -394,8 +386,7 @@ ENTRY main { (d0, d1) -> (d0 * 2 + 3, d1 * 2 + 4), domain: d0 in [0, 1], - d1 in [0, 3], - is_simplified: true + d1 in [0, 3] )")); } @@ -430,8 +421,7 @@ ENTRY main { (d0, d1) -> (d0 * 2, d1 * 2), domain: d0 in [0, 1], - d1 in [0, 7], - is_simplified: true + d1 in [0, 7] )")); const TiledHloInstruction* lhs = dot->operand(0); @@ -441,8 +431,7 @@ ENTRY main { (d0, d1) -> (d0 * 2, 0), domain: d0 in [0, 1], - d1 in [0, 7], - is_simplified: true + d1 in [0, 7] )")); const TiledHloInstruction* rhs = dot->operand(1); @@ -452,8 +441,7 @@ ENTRY main { (d0, d1) -> (0, d1 * 2), domain: d0 in [0, 1], - d1 in [0, 7], - is_simplified: true + d1 in [0, 7] )")); } @@ -911,8 +899,7 @@ ENTRY main { (d0, d1) -> (d0, d1), domain: d0 in [0, 65537], - d1 in [0, 32767], - is_simplified: true + d1 
in [0, 32767] )")); } @@ -967,8 +954,7 @@ ENTRY main { (d0, d1) -> (0, d1, 0), domain: d0 in [0, 0], - d1 in [0, 1], - is_simplified: true + d1 in [0, 1] )")); EXPECT_THAT(*param_0_tile, MatchTiledHloInstruction( @@ -984,8 +970,7 @@ ENTRY main { (d0, d1, d2) -> (), s1 in [0, 226], hlo: %of3 = s32[] parameter(3), - (d0, d1, d2) -> (), - is_simplified: true + (d0, d1, d2) -> () )")); }
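Taken together with the printer change in indexing_map_serialization.cc, every expected string in these tests now ends with its last domain entry or constraint instead of an "is_simplified:" flag: separators are only emitted between entries. A small sketch of that join logic in standard C++ (the real ToString tracks remaining_vars_to_print and uses absl::StrJoin for the sorted constraints):

#include <initializer_list>
#include <iostream>
#include <string>
#include <vector>

// Joins variable bounds and constraints with ", " and no trailing separator,
// mirroring the updated ToString(): nothing follows the final entry.
std::string PrintDomain(const std::vector<std::string>& vars,
                        const std::vector<std::string>& constraints) {
  std::string out = "domain: ";
  bool first = true;
  for (const auto& part : {vars, constraints}) {
    for (const auto& entry : part) {
      if (!first) out += ", ";
      out += entry;
      first = false;
    }
  }
  return out;
}

int main() {
  std::cout << PrintDomain({"d0 in [0, 9]", "s0 in [0, 255]"},
                           {"d0 mod 2 in [0, 1]"})
            << "\n";
  // Prints: domain: d0 in [0, 9], s0 in [0, 255], d0 mod 2 in [0, 1]
  return 0;
}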