Skip to content

Commit

Permalink
[XLA:GPU][IndexAnalysis] Unify parsers for IndexingMap and IndexingMa…
Browse files Browse the repository at this point in the history
…pAttr.

Unfortunately, MLIR does not support multiline string attributes right now, so the lit tests don't look as pretty as before.

PiperOrigin-RevId: 678637171
  • Loading branch information
pifon2a authored and Google-ML-Automation committed Sep 25, 2024
1 parent e56fee4 commit 5497915
Show file tree
Hide file tree
Showing 21 changed files with 511 additions and 724 deletions.
2 changes: 1 addition & 1 deletion xla/service/gpu/fusions/ir/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ cc_library(
":xla_gpu_ops_inc_gen",
":xla_gpu_types_inc_gen",
"//xla/service/gpu/model:indexing_analysis",
"@com_google_absl//absl/strings:str_format",
"//xla/service/gpu/model:indexing_map_serialization",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:ArithDialect",
"@llvm-project//mlir:BytecodeOpInterface",
Expand Down
110 changes: 56 additions & 54 deletions xla/service/gpu/fusions/ir/tests/attrs.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,17 @@
// CHECK-SAME: s0 in [0, 32],
// CHECK-SAME: d0 + s0 in [1, 10],
// CHECK-SAME: d0 mod 2 in [0, 1],
// CHECK-SAME: is_simplified: true
// CHECK-SAME: is_simplified: true"
// CHECK-SAME: >
#map = #xla_gpu.indexing_map<(d0, d1, d2)[s0] -> (d0),
domain:
d0 in [1, 2],
d1 in [5, 8],
d2 in [10, 12],
s0 in [0, 32],
d0 mod 2 in [0, 1],
d0 + s0 in [1, 10],
is_simplified: true
#map = #xla_gpu.indexing_map<"(d0, d1, d2)[s0] -> (d0),"
"domain:"
"d0 in [1, 2],"
"d1 in [5, 8],"
"d2 in [10, 12],"
"s0 in [0, 32],"
"d0 mod 2 in [0, 1],"
"d0 + s0 in [1, 10],"
"is_simplified: true"
>

func.func private @indexing_map_attr(!xla_gpu.indexed_vector<64x64x32xf64, #map>)
Expand All @@ -39,20 +39,21 @@ func.func private @indexing_map_attr(!xla_gpu.indexed_vector<64x64x32xf64, #map>
// CHECK-SAME: d0 + s0 in [1, 10]
// CHECK-SAME: d0 mod 2 in [0, 1]
// CHECK-SAME: d1 + s1 + s2 in [1, 32]
// CHECK-SAME: is_simplified: false
// CHECK-SAME: is_simplified: false"
// CHECK-SAME: >
#map = #xla_gpu.indexing_map<(d0, d1)[s0, s1, s2] -> (d0 + s0, d1 + s1, d1 + s2),
domain:
d0 in [1, 2],
d1 in [5, 8],
s0 in [0, 10],
s1 in [0, 5],
s2 in [0, 32],
d0 mod 2 in [0, 1],
d0 + s0 in [1, 10],
d1 + s1 + s2 in [1, 32],
is_simplified: false
>
#map = #xla_gpu.indexing_map<
"(d0, d1)[s0, s1, s2] -> (d0 + s0, d1 + s1, d1 + s2),"
"domain:"
"d0 in [1, 2],"
"d1 in [5, 8],"
"s0 in [0, 10],"
"s1 in [0, 5],"
"s2 in [0, 32],"
"d0 mod 2 in [0, 1],"
"d0 + s0 in [1, 10],"
"d1 + s1 + s2 in [1, 32],"
"is_simplified: false"
>
func.func private @more_range_vars(!xla_gpu.indexed_vector<100x32xf64, #map>)
// CHECK-LABEL: @more_range_vars
// CHECK: !xla_gpu.indexed_vector<100x32xf64, #[[$INDEX_MAP]]>
Expand All @@ -64,13 +65,13 @@ func.func private @more_range_vars(!xla_gpu.indexed_vector<100x32xf64, #map>)
// CHECK-SAME: domain:
// CHECK-SAME: d0 in [0, 100]
// CHECK-SAME: s0 in [-3, -1]
// CHECK-SAME: is_simplified: false
// CHECK-SAME: is_simplified: false"
// CHECK-SAME: >
#map = #xla_gpu.indexing_map<(d0)[s0] -> (d0),
domain:
d0 in [0, 100],
s0 in [-3, -1],
is_simplified: false
#map = #xla_gpu.indexing_map<"(d0)[s0] -> (d0),"
"domain:"
"d0 in [0, 100],"
"s0 in [-3, -1],"
"is_simplified: false"
>
func.func private @indexing_map_small(!xla_gpu.indexed_vector<100xf64, #map>)
// CHECK-LABEL: @indexing_map_small
Expand All @@ -85,15 +86,15 @@ func.func private @indexing_map_small(!xla_gpu.indexed_vector<100xf64, #map>)
// CHECK-SAME: d1 in [5, 8]
// CHECK-SAME: d2 in [10, 12]
// CHECK-SAME: s0 in [0, 32]
// CHECK-SAME: is_simplified: false
// CHECK-SAME: is_simplified: false"
// CHECK-SAME: >
#map = #xla_gpu.indexing_map<(d0, d1, d2)[s0] -> (d0),
domain:
d0 in [1, 2],
d1 in [5, 8],
d2 in [10, 12],
s0 in [0, 32],
is_simplified: false
#map = #xla_gpu.indexing_map<"(d0, d1, d2)[s0] -> (d0),"
"domain:"
"d0 in [1, 2],"
"d1 in [5, 8],"
"d2 in [10, 12],"
"s0 in [0, 32],"
"is_simplified: false"
>
func.func private @no_constraints(!xla_gpu.indexed_vector<32xf64, #map>)
// CHECK-LABEL: @no_constraints
Expand All @@ -106,13 +107,13 @@ func.func private @no_constraints(!xla_gpu.indexed_vector<32xf64, #map>)
// CHECK-SAME: domain:
// CHECK-SAME: s0 in [3, 5]
// CHECK-SAME: s0 mod 2 in [0, 1]
// CHECK-SAME: is_simplified: false
// CHECK-SAME: is_simplified: false"
// CHECK-SAME: >
#map = #xla_gpu.indexing_map<()[s0] -> (s0),
domain:
s0 in [3, 5],
s0 mod 2 in [0, 1],
is_simplified: false
#map = #xla_gpu.indexing_map<"()[s0] -> (s0),"
"domain:"
"s0 in [3, 5],"
"s0 mod 2 in [0, 1],"
"is_simplified: false"
>
func.func private @no_dimensions(!xla_gpu.indexed_vector<100xf64, #map>)
// CHECK-LABEL: @no_dimensions
Expand All @@ -125,13 +126,13 @@ func.func private @no_dimensions(!xla_gpu.indexed_vector<100xf64, #map>)
// CHECK-SAME: domain:
// CHECK-SAME: d0 in [3, 5]
// CHECK-SAME: d0 mod 2 in [0, 1]
// CHECK-SAME: is_simplified: false
// CHECK-SAME: is_simplified: false"
// CHECK-SAME: >
#map = #xla_gpu.indexing_map<(d0) -> (d0),
domain:
d0 in [3, 5],
d0 mod 2 in [0, 1],
is_simplified: false
#map = #xla_gpu.indexing_map<"(d0) -> (d0),"
"domain:"
"d0 in [3, 5],"
"d0 mod 2 in [0, 1],"
"is_simplified: false"
>
func.func private @no_symbols(!xla_gpu.indexed_vector<100xf64, #map>)
// CHECK-LABEL: @no_symbols
Expand All @@ -142,7 +143,7 @@ func.func private @no_symbols(!xla_gpu.indexed_vector<100xf64, #map>)
// CHECK: #[[$INDEX_MAP:.*]] = #xla_gpu.indexing_map<
// CHECK-SAME: () -> ()
// CHECK-SAME: >
#map = #xla_gpu.indexing_map<() -> ()>
#map = #xla_gpu.indexing_map<"() -> ()">
func.func private @empty(!xla_gpu.indexed_vector<100xf64, #map>)
// CHECK-LABEL: @empty
// CHECK: !xla_gpu.indexed_vector<100xf64, #[[$INDEX_MAP]]>
Expand All @@ -151,7 +152,8 @@ func.func private @empty(!xla_gpu.indexed_vector<100xf64, #map>)

func.func private @tensor_layout(
%in0: tensor<42xf32, #xla_gpu.layout<"shmem",
(d0) -> (), domain: d0 in [0, 42], is_simplified: true>>)
// CHECK: #layout = #xla_gpu.layout<"shmem", (d0) -> (),
// CHECK-SAME: domain: d0 in [0, 42], is_simplified: true>
// CHECK: tensor<42xf32, #layout>
"(d0) -> (),"
"domain: d0 in [0, 42], is_simplified: true">>)
// CHECK: #layout = #xla_gpu.layout<"shmem", "(d0) -> (),
// CHECK-SAME: domain: d0 in [0, 42], is_simplified: true">
// CHECK: tensor<42xf32, #layout>
Loading

0 comments on commit 5497915

Please sign in to comment.