Skip to content

Commit

Permalink
[XLA:GPU][IndexAnalysis] Remove is_simplified flag.
Browse files Browse the repository at this point in the history
The benchmarks don't show a lot of improvements in compile time.

PiperOrigin-RevId: 679573820
  • Loading branch information
pifon2a authored and Google-ML-Automation committed Sep 27, 2024
1 parent 6c3194e commit d539254
Show file tree
Hide file tree
Showing 30 changed files with 500 additions and 902 deletions.
30 changes: 8 additions & 22 deletions xla/service/gpu/fusions/ir/tests/attrs.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,7 @@
// CHECK-SAME: d2 in [10, 12],
// CHECK-SAME: s0 in [0, 32],
// CHECK-SAME: d0 + s0 in [1, 10],
// CHECK-SAME: d0 mod 2 in [0, 1],
// CHECK-SAME: is_simplified: true"
// CHECK-SAME: d0 mod 2 in [0, 1]
// CHECK-SAME: >
#map = #xla_gpu.indexing_map<"(d0, d1, d2)[s0] -> (d0),"
"domain:"
Expand All @@ -18,8 +17,7 @@
"d2 in [10, 12],"
"s0 in [0, 32],"
"d0 mod 2 in [0, 1],"
"d0 + s0 in [1, 10],"
"is_simplified: true"
"d0 + s0 in [1, 10]"
>

func.func private @indexing_map_attr(!xla_gpu.indexed_vector<64x64x32xf64, #map>)
Expand All @@ -39,7 +37,6 @@ func.func private @indexing_map_attr(!xla_gpu.indexed_vector<64x64x32xf64, #map>
// CHECK-SAME: d0 + s0 in [1, 10]
// CHECK-SAME: d0 mod 2 in [0, 1]
// CHECK-SAME: d1 + s1 + s2 in [1, 32]
// CHECK-SAME: is_simplified: false"
// CHECK-SAME: >
#map = #xla_gpu.indexing_map<
"(d0, d1)[s0, s1, s2] -> (d0 + s0, d1 + s1, d1 + s2),"
Expand All @@ -51,8 +48,7 @@ func.func private @indexing_map_attr(!xla_gpu.indexed_vector<64x64x32xf64, #map>
"s2 in [0, 32],"
"d0 mod 2 in [0, 1],"
"d0 + s0 in [1, 10],"
"d1 + s1 + s2 in [1, 32],"
"is_simplified: false"
"d1 + s1 + s2 in [1, 32]"
>
func.func private @more_range_vars(!xla_gpu.indexed_vector<100x32xf64, #map>)
// CHECK-LABEL: @more_range_vars
Expand All @@ -65,13 +61,11 @@ func.func private @more_range_vars(!xla_gpu.indexed_vector<100x32xf64, #map>)
// CHECK-SAME: domain:
// CHECK-SAME: d0 in [0, 100]
// CHECK-SAME: s0 in [-3, -1]
// CHECK-SAME: is_simplified: false"
// CHECK-SAME: >
#map = #xla_gpu.indexing_map<"(d0)[s0] -> (d0),"
"domain:"
"d0 in [0, 100],"
"s0 in [-3, -1],"
"is_simplified: false"
"s0 in [-3, -1]"
>
func.func private @indexing_map_small(!xla_gpu.indexed_vector<100xf64, #map>)
// CHECK-LABEL: @indexing_map_small
Expand All @@ -86,15 +80,13 @@ func.func private @indexing_map_small(!xla_gpu.indexed_vector<100xf64, #map>)
// CHECK-SAME: d1 in [5, 8]
// CHECK-SAME: d2 in [10, 12]
// CHECK-SAME: s0 in [0, 32]
// CHECK-SAME: is_simplified: false"
// CHECK-SAME: >
#map = #xla_gpu.indexing_map<"(d0, d1, d2)[s0] -> (d0),"
"domain:"
"d0 in [1, 2],"
"d1 in [5, 8],"
"d2 in [10, 12],"
"s0 in [0, 32],"
"is_simplified: false"
"s0 in [0, 32]"
>
func.func private @no_constraints(!xla_gpu.indexed_vector<32xf64, #map>)
// CHECK-LABEL: @no_constraints
Expand All @@ -107,13 +99,11 @@ func.func private @no_constraints(!xla_gpu.indexed_vector<32xf64, #map>)
// CHECK-SAME: domain:
// CHECK-SAME: s0 in [3, 5]
// CHECK-SAME: s0 mod 2 in [0, 1]
// CHECK-SAME: is_simplified: false"
// CHECK-SAME: >
#map = #xla_gpu.indexing_map<"()[s0] -> (s0),"
"domain:"
"s0 in [3, 5],"
"s0 mod 2 in [0, 1],"
"is_simplified: false"
"s0 mod 2 in [0, 1]"
>
func.func private @no_dimensions(!xla_gpu.indexed_vector<100xf64, #map>)
// CHECK-LABEL: @no_dimensions
Expand All @@ -126,13 +116,11 @@ func.func private @no_dimensions(!xla_gpu.indexed_vector<100xf64, #map>)
// CHECK-SAME: domain:
// CHECK-SAME: d0 in [3, 5]
// CHECK-SAME: d0 mod 2 in [0, 1]
// CHECK-SAME: is_simplified: false"
// CHECK-SAME: >
#map = #xla_gpu.indexing_map<"(d0) -> (d0),"
"domain:"
"d0 in [3, 5],"
"d0 mod 2 in [0, 1],"
"is_simplified: false"
>
func.func private @no_symbols(!xla_gpu.indexed_vector<100xf64, #map>)
// CHECK-LABEL: @no_symbols
Expand All @@ -152,8 +140,6 @@ func.func private @empty(!xla_gpu.indexed_vector<100xf64, #map>)

func.func private @tensor_layout(
%in0: tensor<42xf32, #xla_gpu.layout<"shmem",
"(d0) -> (),"
"domain: d0 in [0, 42], is_simplified: true">>)
// CHECK: #layout = #xla_gpu.layout<"shmem", "(d0) -> (),
// CHECK-SAME: domain: d0 in [0, 42], is_simplified: true">
"(d0) -> ()," "domain: d0 in [0, 42]">>)
// CHECK: #layout = #xla_gpu.layout<"shmem", "(d0) -> (), domain:
// CHECK: tensor<42xf32, #layout>
14 changes: 6 additions & 8 deletions xla/service/gpu/fusions/ir/xla_gpu_attrs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,7 @@ using mlir::success;
// Parses a chain of string attributes into an indexing map.
// Example:
// "()[s0, s1] -> (1 + s0 + s1 mod 3 - s1, s0 mod 2),"
// " domain: s0 in [-10, 10], s1 in [0, 2],"
// " is_simplified: false"
// " domain: s0 in [-10, 10], s1 in [0, 2]"
// will be parsed as 3 StringAttrs, concatenated into a single string, and then
// parsed into an IndexingMap.
std::optional<IndexingMap> parseChainOfStringsAsIndexingMap(
Expand Down Expand Up @@ -84,17 +83,16 @@ IndexingMapAttr IndexingMapAttr::get(mlir::MLIRContext* context,
constraints.push_back({constraint.first, constraint.second});
}
return get(context, indexing_map.GetAffineMap(), indexing_map.GetDimVars(),
indexing_map.GetRangeVars(), constraints,
indexing_map.IsSimplified());
indexing_map.GetRangeVars(), constraints);
}

mlir::LogicalResult IndexingMapAttr::verify(
mlir::function_ref<mlir::InFlightDiagnostic()> emitError,
mlir::AffineMap map, ArrayRef<DimVar> dim_vars,
ArrayRef<RangeVar> range_vars,
ArrayRef<std::pair<AffineExpr, Interval>> constraints, bool is_simplified) {
auto indexing_map = IndexingMap(map, dim_vars, range_vars, /*rt_vars=*/{},
constraints, is_simplified);
ArrayRef<std::pair<AffineExpr, Interval>> constraints) {
auto indexing_map =
IndexingMap(map, dim_vars, range_vars, /*rt_vars=*/{}, constraints);
std::stringstream ss;
if (!indexing_map.Verify(ss)) {
return emitError() << ss.str();
Expand All @@ -104,7 +102,7 @@ mlir::LogicalResult IndexingMapAttr::verify(

IndexingMap IndexingMapAttr::getIndexingMap() const {
return IndexingMap(getMap(), getDimVars(), getRangeVars(), /*rt_vars=*/{},
getConstraints(), getIsSimplified());
getConstraints());
}

int64_t IndexingMapAttr::getNumResults() const {
Expand Down
5 changes: 1 addition & 4 deletions xla/service/gpu/fusions/ir/xla_gpu_attrs.td
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,6 @@ def XLAGPU_RangeVarsParameter : ArrayRefParameter<"::xla::gpu::RangeVar",
"RangeVarArray"> {
}

def XLAGPU_BoolParameter : AttrOrTypeParameter<"bool", ""> {}

def XLAGPU_ConstraintsParameter :
ArrayRefParameter<"::std::pair<::mlir::AffineExpr, ::xla::gpu::Interval>",
"ContraintsArray"> {
Expand All @@ -52,8 +50,7 @@ def XLAGPU_IndexingMapAttr : XLAGPU_Attr<"IndexingMap"> {
let parameters = (ins XLAGPU_AffineMapParameter:$map,
XLAGPU_DimVarsParameter:$dim_vars,
XLAGPU_RangeVarsParameter:$range_vars,
XLAGPU_ConstraintsParameter:$constraints,
XLAGPU_BoolParameter:$is_simplified);
XLAGPU_ConstraintsParameter:$constraints);
let hasCustomAssemblyFormat = 1;
let builders = [
AttrBuilder<(ins "const ::xla::gpu::IndexingMap&":$indexing_map)>,
Expand Down
15 changes: 7 additions & 8 deletions xla/service/gpu/fusions/ir/xla_gpu_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -183,13 +183,13 @@ ParseResult ApplyIndexingOp::parse(OpAsmParser& parser,
parser.parseOptionalAttrDict(result.attributes)) {
return failure();
}
auto map = indexing_map_attr.getMap();
auto map = indexing_map_attr.getIndexingMap().GetAffineMap();
result.addTypes(SmallVector<Type, 2>(map.getNumResults(), index_type));
return success();
}

void ApplyIndexingOp::print(OpAsmPrinter& p) {
AffineMap affine_map = getIndexingMapAttr().getMap();
AffineMap affine_map = getIndexingMapAttr().getIndexingMap().GetAffineMap();
p << " " << getIndexingMapAttr();

auto operands = getOperands();
Expand All @@ -214,14 +214,14 @@ void ApplyIndexingOp::print(OpAsmPrinter& p) {
}

LogicalResult ApplyIndexingOp::verify() {
auto affine_map = getIndexingMapAttr().getMap();
auto affine_map = getIndexingMapAttr().getIndexingMap().GetAffineMap();
unsigned num_variables = affine_map.getNumDims() + affine_map.getNumSymbols();
if (getOperands().size() != num_variables) {
return emitOpError(
"operand count must match the number of dimensions and symbols in the "
"affine map");
}
if (!getIndexingMapAttr().getConstraints().empty()) {
if (!getIndexingMap().GetConstraints().empty()) {
return emitOpError("apply indexing op cannot have any constraints");
}
return success();
Expand Down Expand Up @@ -310,11 +310,10 @@ struct SimplifyIndexingMap : public mlir::OpRewritePattern<ApplyIndexingOp> {
LogicalResult matchAndRewrite(ApplyIndexingOp indexing_op,
PatternRewriter& rewriter) const override {
IndexingMap indexing_map = indexing_op.getIndexingMap();
if (indexing_map.IsSimplified()) {
if (!indexing_map.Simplify()) {
return rewriter.notifyMatchFailure(indexing_op,
"IndexingMap is already simplified");
}
indexing_map.Simplify();
rewriter.replaceOpWithNewOp<ApplyIndexingOp>(
indexing_op, indexing_op.getOperands(), indexing_map);
return success();
Expand Down Expand Up @@ -1046,12 +1045,12 @@ LogicalResult MaterializeOp::verify() {
//===----------------------------------------------------------------------===//

LogicalResult InsertOp::verify() {
if (!getMap().getRangeVars().empty()) {
if (!getMap().getIndexingMap().GetRangeVars().empty()) {
return emitOpError() << "insert_op map must not have any symbols";
}
int64_t vector_map_num_results =
getSource().getType().getIndexingMapAttr().getNumResults();
if (vector_map_num_results != getMap().getDimVars().size()) {
if (vector_map_num_results != getMap().getIndexingMap().GetDimVars().size()) {
return emitOpError() << "source map result count must equal insert_op's "
"map's dimension count";
}
Expand Down
3 changes: 1 addition & 2 deletions xla/service/gpu/fusions/legacy/concatenate_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,7 @@ TEST_F(ConcatenateTest, ThreadIndexing) {
bl_z in [0, 0],
chunk_id in [0, 0],
unroll_id in [0, 0],
bl_x * 128 + th_x in [0, 399],
is_simplified: true
bl_x * 128 + th_x in [0, 399]
)";
EXPECT_THAT(
ToString(*fusion->ComputeThreadIdToInputIndexing(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,7 @@ TEST_F(InPlaceDynamicUpdateSliceFusionTest, ThreadIndexing) {
bl_y in [0, 0],
bl_z in [0, 0],
chunk_id in [0, 0],
unroll_id in [0, 0],
is_simplified: true
unroll_id in [0, 0]
)"));
auto thread_id_dst_indexing = fusion->ComputeThreadIdToInputIndexing(
/*root_index=*/0, /*hero_operand_index=*/0, &mlir_context_);
Expand Down
3 changes: 1 addition & 2 deletions xla/service/gpu/fusions/legacy/input_slices_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,7 @@ TEST_F(InputSlicesTest, ThreadIndexing) {
bl_z in [0, 0],
chunk_id in [0, 0],
unroll_id in [0, 0],
bl_x * 128 + th_x in [0, 29],
is_simplified: true
bl_x * 128 + th_x in [0, 29]
)"));
}

Expand Down
15 changes: 5 additions & 10 deletions xla/service/gpu/fusions/legacy/loop_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,7 @@ TEST_F(LoopTest, ThreadIndexingUnrolled) {
bl_z in [0, 0],
chunk_id in [0, 0],
unroll_id in [0, 3],
bl_x * 128 + th_x in [0, 1499999],
is_simplified: true
bl_x * 128 + th_x in [0, 1499999]
)"));
}

Expand Down Expand Up @@ -133,8 +132,7 @@ TEST_F(LoopTest, ThreadIndexingNotUnrolled) {
bl_y in [0, 0],
bl_z in [0, 0],
chunk_id in [0, 0],
unroll_id in [0, 0],
is_simplified: true
unroll_id in [0, 0]
)"));
auto thread_id_to_input_indexing =
loop_fusion->ComputeThreadIdToInputIndexing(
Expand All @@ -152,8 +150,7 @@ TEST_F(LoopTest, ThreadIndexingNotUnrolled) {
bl_y in [0, 0],
bl_z in [0, 0],
chunk_id in [0, 0],
unroll_id in [0, 0],
is_simplified: true
unroll_id in [0, 0]
)"));
}

Expand Down Expand Up @@ -196,8 +193,7 @@ TEST_F(LoopTest, Broadcast) {
bl_z in [0, 0],
chunk_id in [0, 0],
unroll_id in [0, 0],
bl_x * 128 + th_x in [0, 5999],
is_simplified: true
bl_x * 128 + th_x in [0, 5999]
)"));
auto thread_id_to_input_indexing =
loop_fusion->ComputeThreadIdToInputIndexing(
Expand All @@ -217,8 +213,7 @@ TEST_F(LoopTest, Broadcast) {
bl_z in [0, 0],
chunk_id in [0, 0],
unroll_id in [0, 0],
bl_x * 128 + th_x in [0, 5999],
is_simplified: true
bl_x * 128 + th_x in [0, 5999]
)"));
}

Expand Down
6 changes: 2 additions & 4 deletions xla/service/gpu/fusions/legacy/reduction_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,7 @@ TEST_F(ReductionTest, ThreadIndexingRowReduction) {
s0 in [0, 0],
s1 in [0, 0],
s2 in [0, 7],
s3 in [0, 1],
is_simplified: true
s3 in [0, 1]
)"));
EXPECT_THAT(
ToString(*fusion.ComputeThreadIdToOutputIndexing(0, &mlir_context_)),
Expand All @@ -103,8 +102,7 @@ TEST_F(ReductionTest, ThreadIndexingRowReduction) {
d3 in [0, 799],
d4 in [0, 0],
d5 in [0, 0],
d0 mod 32 in [0, 0],
is_simplified: true
d0 mod 32 in [0, 0]
)"));
}

Expand Down
6 changes: 2 additions & 4 deletions xla/service/gpu/fusions/legacy/scatter_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -156,8 +156,7 @@ TEST_F(ScatterFusionTest, ThreadIdIndexing) {
bl_z in [0, 0],
chunk_id in [0, 0],
unroll_id in [0, 0],
bl_x * 128 + th_x in [0, 8399],
is_simplified: true
bl_x * 128 + th_x in [0, 8399]
)";
mlir::SmallVector<std::string> dim_names = {"th_x", "th_y", "th_z",
"bl_x", "bl_y", "bl_z"};
Expand Down Expand Up @@ -197,8 +196,7 @@ TEST_F(ScatterFusionTest, ThreadIdIndexing) {
chunk_id in [0, 0],
unroll_id in [0, 0],
index_id in [0, 0],
bl_x * 128 + th_x in [0, 8399],
is_simplified: true
bl_x * 128 + th_x in [0, 8399]
)";
EXPECT_THAT(
ToString(*fusion->ComputeThreadIdToInputIndexing(
Expand Down
Loading

0 comments on commit d539254

Please sign in to comment.