diff --git a/xla/hlo/transforms/simplifiers/algebraic_simplifier.cc b/xla/hlo/transforms/simplifiers/algebraic_simplifier.cc index c3ccd4c4037ce..7df2be6b69681 100644 --- a/xla/hlo/transforms/simplifiers/algebraic_simplifier.cc +++ b/xla/hlo/transforms/simplifiers/algebraic_simplifier.cc @@ -15,8 +15,6 @@ limitations under the License. #include "xla/hlo/transforms/simplifiers/algebraic_simplifier.h" -#include - #include #include #include @@ -3542,335 +3540,338 @@ AlgebraicSimplifierVisitor::AssociativeReorderNestedDot(HloDotInstruction* dot, outer_rhs_dot = true; } - if ((outer_lhs_dot || outer_rhs_dot) && - !Cast(inner)->sparse_operands()) { - DotDimensionNumbers ab_dnums, ac_dnums, bc_dnums; - - // We will now use inner and outer to build up ab_dnums, ac_dnums, and - // bc_dnums. One of these three comes for free from inner - if (outer_lhs_dot) { - ab_dnums = inner->dot_dimension_numbers(); - } else if (outer_rhs_dot) { - bc_dnums = inner->dot_dimension_numbers(); - } - - // For the other two, it's more complicated. First, we construct maps from - // the dimensions of inner to the dimensions of inner's operands - std::vector map_inner_lhs, map_inner_rhs; - std::tie(map_inner_lhs, map_inner_rhs) = ConstructFromDotMaps( - inner, inner->operand(0)->shape(), inner->operand(1)->shape()); - DotDimensionNumbers outer_dnums = outer->dot_dimension_numbers(); - - // We now iterate through the batch dimensions of outer, and recover - // the batch dimensions shared between each operand of inner and the - // other operand of outer - for (int64_t i = 0; i < outer_dnums.lhs_batch_dimensions_size(); ++i) { - // First we retrieve inner_index and other_index depending on which side - // of outer that inner is on - int64_t inner_index, other_index; - if (outer_lhs_dot) { - inner_index = outer_dnums.lhs_batch_dimensions(i); - other_index = outer_dnums.rhs_batch_dimensions(i); - } else { - inner_index = outer_dnums.rhs_batch_dimensions(i); - other_index = outer_dnums.lhs_batch_dimensions(i); - } - - auto add_batch_dims = [](DotDimensionNumbers& dnums, int64_t lhs_ix, - int64_t rhs_ix) { - dnums.add_lhs_batch_dimensions(lhs_ix); - dnums.add_rhs_batch_dimensions(rhs_ix); - }; + if ((!outer_lhs_dot && !outer_rhs_dot)) { + return RewriteResult::kNoRewrite; + } - for (auto& map : {map_inner_lhs, map_inner_rhs}) { - int64_t mapped_index = map[inner_index]; - if (mapped_index != -1) { - // Whether the mapped value is the lhs or rhs of the new dnums - // depends on whether inner is the lhs or rhs operand of outer. The - // dnums itself depends on this and also on which map we are - // iterating through - if (outer_lhs_dot) { - add_batch_dims(map == map_inner_lhs ? ac_dnums : bc_dnums, - mapped_index, other_index); - } else { - add_batch_dims(map == map_inner_lhs ? ab_dnums : ac_dnums, - other_index, mapped_index); - } - } - } - } + if (Cast(inner)->sparse_operands()) { + return RewriteResult::kNoRewrite; + } - // We now do the same thing for the contracting dimensions of outer - for (int64_t i = 0; i < outer_dnums.lhs_contracting_dimensions_size(); - ++i) { - // First we retrieve inner_index and other_index depending on which side - // of outer that inner is on - int64_t inner_index, other_index; - if (outer_lhs_dot) { - inner_index = outer_dnums.lhs_contracting_dimensions(i); - other_index = outer_dnums.rhs_contracting_dimensions(i); - } else { - inner_index = outer_dnums.rhs_contracting_dimensions(i); - other_index = outer_dnums.lhs_contracting_dimensions(i); - } + DotDimensionNumbers ab_dnums, ac_dnums, bc_dnums; - // Once we have the inner_index, we determine whether this index - // corresponds to a dimension coming from the lhs or rhs of inner - bool from_inner_lhs = map_inner_lhs[inner_index] != -1; - bool from_inner_rhs = map_inner_rhs[inner_index] != -1; + // We will now use inner and outer to build up ab_dnums, ac_dnums, and + // bc_dnums. One of these three comes for free from inner + if (outer_lhs_dot) { + ab_dnums = inner->dot_dimension_numbers(); + } else if (outer_rhs_dot) { + bc_dnums = inner->dot_dimension_numbers(); + } - // If a dimension of inner is the result of batching and it is - // contracted in outer, we stop trying to reorder - if (from_inner_lhs && from_inner_rhs) { - return RewriteResult::kStopRewrites; - } + // For the other two, it's more complicated. First, we construct maps from + // the dimensions of inner to the dimensions of inner's operands + std::vector map_inner_lhs, map_inner_rhs; + std::tie(map_inner_lhs, map_inner_rhs) = ConstructFromDotMaps( + inner, inner->operand(0)->shape(), inner->operand(1)->shape()); + DotDimensionNumbers outer_dnums = outer->dot_dimension_numbers(); - // The map we use depends on which operand of inner this dim comes from - std::vector map; - if (from_inner_lhs) { - map = map_inner_lhs; - } else { - map = map_inner_rhs; - } + // We now iterate through the batch dimensions of outer, and recover + // the batch dimensions shared between each operand of inner and the + // other operand of outer + for (int64_t i = 0; i < outer_dnums.lhs_batch_dimensions_size(); ++i) { + // First we retrieve inner_index and other_index depending on which side + // of outer that inner is on + int64_t inner_index, other_index; + if (outer_lhs_dot) { + inner_index = outer_dnums.lhs_batch_dimensions(i); + other_index = outer_dnums.rhs_batch_dimensions(i); + } else { + inner_index = outer_dnums.rhs_batch_dimensions(i); + other_index = outer_dnums.lhs_batch_dimensions(i); + } - // Whether the mapped value goes into the lhs or rhs of the new dnums - // depends on whether inner was the lhs or rhs operand of outer - int64_t lhs_index, rhs_index; - if (outer_lhs_dot) { - lhs_index = map[inner_index]; - rhs_index = other_index; - } else { - lhs_index = other_index; - rhs_index = map[inner_index]; - } + auto add_batch_dims = [](DotDimensionNumbers& dnums, int64_t lhs_ix, + int64_t rhs_ix) { + dnums.add_lhs_batch_dimensions(lhs_ix); + dnums.add_rhs_batch_dimensions(rhs_ix); + }; - // Finally, we have to determine which dnums to add to - DotDimensionNumbers* dnums; - if (outer_lhs_dot) { - if (from_inner_lhs) { - dnums = &ac_dnums; - } else { - dnums = &bc_dnums; - } - } else { - if (from_inner_lhs) { - dnums = &ab_dnums; + for (auto& map : {map_inner_lhs, map_inner_rhs}) { + int64_t mapped_index = map[inner_index]; + if (mapped_index != -1) { + // Whether the mapped value is the lhs or rhs of the new dnums + // depends on whether inner is the lhs or rhs operand of outer. The + // dnums itself depends on this and also on which map we are + // iterating through + if (outer_lhs_dot) { + add_batch_dims(map == map_inner_lhs ? ac_dnums : bc_dnums, + mapped_index, other_index); } else { - dnums = &ac_dnums; + add_batch_dims(map == map_inner_lhs ? ab_dnums : ac_dnums, + other_index, mapped_index); } } - - // Add the contracting dimensions - dnums->add_lhs_contracting_dimensions(lhs_index); - dnums->add_rhs_contracting_dimensions(rhs_index); } + } - // ab_dnums, ac_dnums, and bc_dnums are now complete. We can now use these - // dnums to construct the dnums for the new_inner and new_outer. - HloInstruction *new_inner_lhs, *new_inner_rhs; - DotDimensionNumbers new_inner_dnums; + // We now do the same thing for the contracting dimensions of outer + for (int64_t i = 0; i < outer_dnums.lhs_contracting_dimensions_size(); ++i) { + // First we retrieve inner_index and other_index depending on which side + // of outer that inner is on + int64_t inner_index, other_index; if (outer_lhs_dot) { - new_inner_lhs = inner->mutable_operand(1); - new_inner_rhs = outer->mutable_operand(1); - new_inner_dnums = bc_dnums; + inner_index = outer_dnums.lhs_contracting_dimensions(i); + other_index = outer_dnums.rhs_contracting_dimensions(i); } else { - new_inner_lhs = outer->mutable_operand(0); - new_inner_rhs = inner->mutable_operand(0); - new_inner_dnums = ab_dnums; + inner_index = outer_dnums.rhs_contracting_dimensions(i); + other_index = outer_dnums.lhs_contracting_dimensions(i); + } + + // Once we have the inner_index, we determine whether this index + // corresponds to a dimension coming from the lhs or rhs of inner + bool from_inner_lhs = map_inner_lhs[inner_index] != -1; + bool from_inner_rhs = map_inner_rhs[inner_index] != -1; + + // If a dimension of inner is the result of batching and it is + // contracted in outer, we stop trying to reorder + if (from_inner_lhs && from_inner_rhs) { + return RewriteResult::kStopRewrites; } - // For dnums for new_outer, we will need some additional maps - std::vector map_lhs_new_inner, map_rhs_new_inner; - std::tie(map_lhs_new_inner, map_rhs_new_inner) = ConstructToDotMaps( - new_inner_dnums, new_inner_lhs->shape(), new_inner_rhs->shape()); - DotDimensionNumbers new_outer_dnums; + // The map we use depends on which operand of inner this dim comes from + std::vector map; + if (from_inner_lhs) { + map = map_inner_lhs; + } else { + map = map_inner_rhs; + } - // To build up new_outer dnums, we need to combine two "pairs". If the - // inner dot was originally on lhs, these pairs are ab and ac. If the - // inner dot was originally on the rhs, these pairs ac and bc - std::vector dnums_to_reorder; + // Whether the mapped value goes into the lhs or rhs of the new dnums + // depends on whether inner was the lhs or rhs operand of outer + int64_t lhs_index, rhs_index; if (outer_lhs_dot) { - dnums_to_reorder.push_back(ab_dnums); - dnums_to_reorder.push_back(ac_dnums); + lhs_index = map[inner_index]; + rhs_index = other_index; } else { - dnums_to_reorder.push_back(ac_dnums); - dnums_to_reorder.push_back(bc_dnums); + lhs_index = other_index; + rhs_index = map[inner_index]; } - // We now iterate through the batch and contracting dimensions of each - // pair, using the previously constructed maps to add to new_outer dnums - for (int pair = 0; pair < 2; ++pair) { - DotDimensionNumbers dnums = dnums_to_reorder[pair]; - std::vector map = - (pair % 2) == 0 ? map_lhs_new_inner : map_rhs_new_inner; + // Finally, we have to determine which dnums to add to + DotDimensionNumbers* dnums; + if (outer_lhs_dot) { + if (from_inner_lhs) { + dnums = &ac_dnums; + } else { + dnums = &bc_dnums; + } + } else { + if (from_inner_lhs) { + dnums = &ab_dnums; + } else { + dnums = &ac_dnums; + } + } - for (int64_t i = 0; i < dnums.lhs_batch_dimensions_size(); ++i) { - int64_t new_inner_index, other_index; - if (outer_lhs_dot) { - new_inner_index = dnums.rhs_batch_dimensions(i); - other_index = dnums.lhs_batch_dimensions(i); - } else { - new_inner_index = dnums.lhs_batch_dimensions(i); - other_index = dnums.rhs_batch_dimensions(i); - } + // Add the contracting dimensions + dnums->add_lhs_contracting_dimensions(lhs_index); + dnums->add_rhs_contracting_dimensions(rhs_index); + } - int64_t lhs_index, rhs_index; - if (outer_lhs_dot) { - lhs_index = other_index; - rhs_index = map[new_inner_index]; - } else { - lhs_index = map[new_inner_index]; - rhs_index = other_index; - } + // ab_dnums, ac_dnums, and bc_dnums are now complete. We can now use these + // dnums to construct the dnums for the new_inner and new_outer. + HloInstruction *new_inner_lhs, *new_inner_rhs; + DotDimensionNumbers new_inner_dnums; + if (outer_lhs_dot) { + new_inner_lhs = inner->mutable_operand(1); + new_inner_rhs = outer->mutable_operand(1); + new_inner_dnums = bc_dnums; + } else { + new_inner_lhs = outer->mutable_operand(0); + new_inner_rhs = inner->mutable_operand(0); + new_inner_dnums = ab_dnums; + } + + // For dnums for new_outer, we will need some additional maps + std::vector map_lhs_new_inner, map_rhs_new_inner; + std::tie(map_lhs_new_inner, map_rhs_new_inner) = ConstructToDotMaps( + new_inner_dnums, new_inner_lhs->shape(), new_inner_rhs->shape()); + DotDimensionNumbers new_outer_dnums; + + // To build up new_outer dnums, we need to combine two "pairs". If the + // inner dot was originally on lhs, these pairs are ab and ac. If the + // inner dot was originally on the rhs, these pairs ac and bc + std::vector dnums_to_reorder; + if (outer_lhs_dot) { + dnums_to_reorder.push_back(ab_dnums); + dnums_to_reorder.push_back(ac_dnums); + } else { + dnums_to_reorder.push_back(ac_dnums); + dnums_to_reorder.push_back(bc_dnums); + } - if (!absl::c_linear_search(new_outer_dnums.lhs_batch_dimensions(), - lhs_index)) { - new_outer_dnums.add_lhs_batch_dimensions(lhs_index); - new_outer_dnums.add_rhs_batch_dimensions(rhs_index); - } + // We now iterate through the batch and contracting dimensions of each + // pair, using the previously constructed maps to add to new_outer dnums + for (int pair = 0; pair < 2; ++pair) { + DotDimensionNumbers dnums = dnums_to_reorder[pair]; + std::vector map = + (pair % 2) == 0 ? map_lhs_new_inner : map_rhs_new_inner; + + for (int64_t i = 0; i < dnums.lhs_batch_dimensions_size(); ++i) { + int64_t new_inner_index, other_index; + if (outer_lhs_dot) { + new_inner_index = dnums.rhs_batch_dimensions(i); + other_index = dnums.lhs_batch_dimensions(i); + } else { + new_inner_index = dnums.lhs_batch_dimensions(i); + other_index = dnums.rhs_batch_dimensions(i); } - for (int64_t i = 0; i < dnums.lhs_contracting_dimensions_size(); ++i) { - int64_t new_inner_index, other_index; - if (outer_lhs_dot) { - new_inner_index = dnums.rhs_contracting_dimensions(i); - other_index = dnums.lhs_contracting_dimensions(i); - } else { - new_inner_index = dnums.lhs_contracting_dimensions(i); - other_index = dnums.rhs_contracting_dimensions(i); - } - int64_t lhs_index, rhs_index; - if (outer_lhs_dot) { - lhs_index = other_index; - rhs_index = map[new_inner_index]; - } else { - lhs_index = map[new_inner_index]; - rhs_index = other_index; - } + int64_t lhs_index, rhs_index; + if (outer_lhs_dot) { + lhs_index = other_index; + rhs_index = map[new_inner_index]; + } else { + lhs_index = map[new_inner_index]; + rhs_index = other_index; + } - new_outer_dnums.add_lhs_contracting_dimensions(lhs_index); - new_outer_dnums.add_rhs_contracting_dimensions(rhs_index); + if (!absl::c_linear_search(new_outer_dnums.lhs_batch_dimensions(), + lhs_index)) { + new_outer_dnums.add_lhs_batch_dimensions(lhs_index); + new_outer_dnums.add_rhs_batch_dimensions(rhs_index); } } + for (int64_t i = 0; i < dnums.lhs_contracting_dimensions_size(); ++i) { + int64_t new_inner_index, other_index; + if (outer_lhs_dot) { + new_inner_index = dnums.rhs_contracting_dimensions(i); + other_index = dnums.lhs_contracting_dimensions(i); + } else { + new_inner_index = dnums.lhs_contracting_dimensions(i); + other_index = dnums.rhs_contracting_dimensions(i); + } - // Get Shape for new_inner - TF_ASSIGN_OR_RETURN( - Shape new_inner_shape, - ShapeInference::InferDotOpShape(new_inner_lhs->shape(), - new_inner_rhs->shape(), new_inner_dnums, - new_inner_lhs->shape().element_type())); - Shape new_outer_lhs_shape = - outer_lhs_dot ? inner->operand(0)->shape() : new_inner_shape; - - // Use HloCostAnalysis to compute flops for both the original and - // reordered instructions, and reorder if doing so decreases flops by a - // factor of the reordering threshold. - const int64_t old_flops = - HloCostAnalysis::GetDotFlops(inner->operand(0)->shape(), inner->shape(), - inner->dot_dimension_numbers()) + - HloCostAnalysis::GetDotFlops(outer->operand(0)->shape(), outer->shape(), - outer_dnums); - const int64_t new_flops = - HloCostAnalysis::GetDotFlops(new_inner_lhs->shape(), new_inner_shape, - new_inner_dnums) + - HloCostAnalysis::GetDotFlops(new_outer_lhs_shape, outer->shape(), - new_outer_dnums); - - if (old_flops / static_cast(new_flops) > - options_.associative_reordering_threshold()) { - // We can now make the Hlo for new_inner and new_outer - TF_ASSIGN_OR_RETURN( - new_inner, - MakeDotHlo(new_inner_lhs, new_inner_rhs, new_inner_dnums, - dot->precision_config(), dot->shape().element_type())); - HloInstruction *new_outer_lhs, *new_outer_rhs; + int64_t lhs_index, rhs_index; if (outer_lhs_dot) { - new_outer_lhs = inner->mutable_operand(0); - new_outer_rhs = new_inner; + lhs_index = other_index; + rhs_index = map[new_inner_index]; } else { - new_outer_lhs = new_inner; - new_outer_rhs = inner->mutable_operand(1); + lhs_index = map[new_inner_index]; + rhs_index = other_index; } - TF_ASSIGN_OR_RETURN( - new_outer, - MakeDotHlo(new_outer_lhs, new_outer_rhs, new_outer_dnums, - dot->precision_config(), dot->shape().element_type())); - - // Depending on the batch dimensions of the original instruction, - // reordering may permute the dimensions of the shape. To correct for - // this, we build a map from old_outer dimensions to new_outer - // dimensions and use it to transpose new_outer. - DimensionVector permutation(new_outer->shape().rank()); - - // Construct additional maps to make the permutation - std::vector map_outer_lhs, map_outer_rhs; - std::tie(map_outer_lhs, map_outer_rhs) = ConstructFromDotMaps( - outer, outer->operand(0)->shape(), outer->operand(1)->shape()); - - std::vector map_outer_inner, map_outer_other; - map_outer_inner = outer_lhs_dot ? map_outer_lhs : map_outer_rhs; - map_outer_other = outer_lhs_dot ? map_outer_rhs : map_outer_lhs; - - std::vector map_inner_new_other; - map_inner_new_other = outer_lhs_dot ? map_inner_lhs : map_inner_rhs; - - std::vector map_other_new_inner; - map_other_new_inner = - outer_lhs_dot ? map_rhs_new_inner : map_lhs_new_inner; - - std::vector map_lhs_new_outer, map_rhs_new_outer; - std::tie(map_lhs_new_outer, map_rhs_new_outer) = - ConstructToDotMaps(new_outer_dnums, new_outer->operand(0)->shape(), - new_outer->operand(1)->shape()); - - std::vector map_new_inner_new_outer, map_new_other_new_outer; - map_new_inner_new_outer = - outer_lhs_dot ? map_rhs_new_outer : map_lhs_new_outer; - map_new_other_new_outer = - outer_lhs_dot ? map_lhs_new_outer : map_rhs_new_outer; - - // Create permutation to do the transpose - bool add_transpose = false; - for (int64_t i = 0; i < outer->shape().rank(); i++) { - int64_t new_outer_index; - if (map_outer_other[i] == -1) { - int64_t inner_index = map_outer_inner[i]; - if (map_inner_new_other[inner_index] == -1) { - int64_t new_inner_index; - if (outer_lhs_dot) { - new_inner_index = map_lhs_new_inner[map_inner_rhs[inner_index]]; - } else { - new_inner_index = map_rhs_new_inner[map_inner_lhs[inner_index]]; - } - new_outer_index = map_new_inner_new_outer[new_inner_index]; + + new_outer_dnums.add_lhs_contracting_dimensions(lhs_index); + new_outer_dnums.add_rhs_contracting_dimensions(rhs_index); + } + } + + // Get Shape for new_inner + TF_ASSIGN_OR_RETURN( + Shape new_inner_shape, + ShapeInference::InferDotOpShape(new_inner_lhs->shape(), + new_inner_rhs->shape(), new_inner_dnums, + new_inner_lhs->shape().element_type())); + Shape new_outer_lhs_shape = + outer_lhs_dot ? inner->operand(0)->shape() : new_inner_shape; + + // Use HloCostAnalysis to compute flops for both the original and + // reordered instructions, and reorder if doing so decreases flops by a + // factor of the reordering threshold. + const int64_t old_flops = + HloCostAnalysis::GetDotFlops(inner->operand(0)->shape(), inner->shape(), + inner->dot_dimension_numbers()) + + HloCostAnalysis::GetDotFlops(outer->operand(0)->shape(), outer->shape(), + outer_dnums); + const int64_t new_flops = + HloCostAnalysis::GetDotFlops(new_inner_lhs->shape(), new_inner_shape, + new_inner_dnums) + + HloCostAnalysis::GetDotFlops(new_outer_lhs_shape, outer->shape(), + new_outer_dnums); + + if (old_flops / static_cast(new_flops) > + options_.associative_reordering_threshold()) { + // We can now make the Hlo for new_inner and new_outer + TF_ASSIGN_OR_RETURN( + new_inner, + MakeDotHlo(new_inner_lhs, new_inner_rhs, new_inner_dnums, + dot->precision_config(), dot->shape().element_type())); + HloInstruction *new_outer_lhs, *new_outer_rhs; + if (outer_lhs_dot) { + new_outer_lhs = inner->mutable_operand(0); + new_outer_rhs = new_inner; + } else { + new_outer_lhs = new_inner; + new_outer_rhs = inner->mutable_operand(1); + } + TF_ASSIGN_OR_RETURN( + new_outer, + MakeDotHlo(new_outer_lhs, new_outer_rhs, new_outer_dnums, + dot->precision_config(), dot->shape().element_type())); + + // Depending on the batch dimensions of the original instruction, + // reordering may permute the dimensions of the shape. To correct for + // this, we build a map from old_outer dimensions to new_outer + // dimensions and use it to transpose new_outer. + DimensionVector permutation(new_outer->shape().rank()); + + // Construct additional maps to make the permutation + std::vector map_outer_lhs, map_outer_rhs; + std::tie(map_outer_lhs, map_outer_rhs) = ConstructFromDotMaps( + outer, outer->operand(0)->shape(), outer->operand(1)->shape()); + + std::vector map_outer_inner, map_outer_other; + map_outer_inner = outer_lhs_dot ? map_outer_lhs : map_outer_rhs; + map_outer_other = outer_lhs_dot ? map_outer_rhs : map_outer_lhs; + + std::vector map_inner_new_other; + map_inner_new_other = outer_lhs_dot ? map_inner_lhs : map_inner_rhs; + + std::vector map_other_new_inner; + map_other_new_inner = outer_lhs_dot ? map_rhs_new_inner : map_lhs_new_inner; + + std::vector map_lhs_new_outer, map_rhs_new_outer; + std::tie(map_lhs_new_outer, map_rhs_new_outer) = + ConstructToDotMaps(new_outer_dnums, new_outer->operand(0)->shape(), + new_outer->operand(1)->shape()); + + std::vector map_new_inner_new_outer, map_new_other_new_outer; + map_new_inner_new_outer = + outer_lhs_dot ? map_rhs_new_outer : map_lhs_new_outer; + map_new_other_new_outer = + outer_lhs_dot ? map_lhs_new_outer : map_rhs_new_outer; + + // Create permutation to do the transpose + bool add_transpose = false; + for (int64_t i = 0; i < outer->shape().rank(); i++) { + int64_t new_outer_index; + if (map_outer_other[i] == -1) { + int64_t inner_index = map_outer_inner[i]; + if (map_inner_new_other[inner_index] == -1) { + int64_t new_inner_index; + if (outer_lhs_dot) { + new_inner_index = map_lhs_new_inner[map_inner_rhs[inner_index]]; } else { - int64_t new_other_index = map_inner_new_other[inner_index]; - new_outer_index = map_new_other_new_outer[new_other_index]; + new_inner_index = map_rhs_new_inner[map_inner_lhs[inner_index]]; } + new_outer_index = map_new_inner_new_outer[new_inner_index]; } else { - // Dimension i in outer comes from other - int64_t other_index = map_outer_other[i]; - new_outer_index = - map_new_inner_new_outer[map_other_new_inner[other_index]]; - } - permutation[i] = new_outer_index; - if (i != new_outer_index) { - add_transpose = true; + int64_t new_other_index = map_inner_new_other[inner_index]; + new_outer_index = map_new_other_new_outer[new_other_index]; } - } - - if (add_transpose) { - HloInstruction* transposed_new_outer; - TF_ASSIGN_OR_RETURN(transposed_new_outer, - MakeTransposeHlo(new_outer, permutation)); - VLOG(10) << "Reordering with associativity and transpose"; - TF_RETURN_IF_ERROR(ReplaceInstruction(dot, transposed_new_outer)); } else { - VLOG(10) << "Reordering with associativity"; - TF_RETURN_IF_ERROR(ReplaceInstruction(dot, new_outer)); + // Dimension i in outer comes from other + int64_t other_index = map_outer_other[i]; + new_outer_index = + map_new_inner_new_outer[map_other_new_inner[other_index]]; + } + permutation[i] = new_outer_index; + if (i != new_outer_index) { + add_transpose = true; } - return RewriteResult::kRewritten; } + + if (add_transpose) { + HloInstruction* transposed_new_outer; + TF_ASSIGN_OR_RETURN(transposed_new_outer, + MakeTransposeHlo(new_outer, permutation)); + VLOG(10) << "Reordering with associativity and transpose"; + TF_RETURN_IF_ERROR(ReplaceInstruction(dot, transposed_new_outer)); + } else { + VLOG(10) << "Reordering with associativity"; + TF_RETURN_IF_ERROR(ReplaceInstruction(dot, new_outer)); + } + return RewriteResult::kRewritten; } return RewriteResult::kNoRewrite; } @@ -3961,11 +3962,13 @@ absl::Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) { const bool is_packed_nibble = absl::c_linear_search(dot->precision_config().operand_precision(), PrecisionConfig::PACKED_NIBBLE); - const bool has_precision_config_algorithm = - dot->precision_config().algorithm() != PrecisionConfig::ALG_UNSET; + const bool can_rewrite_dot_with_precision_config_algorithm = + dot->precision_config().algorithm() == PrecisionConfig::ALG_UNSET || + dot->precision_config().algorithm() == + PrecisionConfig::ALG_DOT_F32_F32_F32; // If there are no contracting dimensions, a dot can be rewritten as // mul(broadcast(transpose(x)),broadcast(transpose(y))) - if (!is_packed_nibble && !has_precision_config_algorithm && + if (!is_packed_nibble && can_rewrite_dot_with_precision_config_algorithm && options_.enable_dot_to_multiply_rewrite() && dnums.lhs_contracting_dimensions_size() == 0) { return RewriteAsMultiplyDotWithZeroLhsContractingDim(dot, lhs, rhs, dnums); @@ -3992,7 +3995,7 @@ absl::Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) { // If the lhs or rhs have only batch and contracting dimensions, a dot can be // rewritten as reduce(mul(broadcast(transpose(x)),broadcast(transpose(y)))) - if (!is_packed_nibble && !has_precision_config_algorithm && + if (!is_packed_nibble && can_rewrite_dot_with_precision_config_algorithm && options_.enable_dot_strength_reduction() && DotHasOnlyBatchAndContractingOnOneOperand(lhs->shape().rank(), rhs->shape().rank(), dnums) && diff --git a/xla/hlo/transforms/simplifiers/algebraic_simplifier_test.cc b/xla/hlo/transforms/simplifiers/algebraic_simplifier_test.cc index 11b7ef5c3abab..41d11484d78dc 100644 --- a/xla/hlo/transforms/simplifiers/algebraic_simplifier_test.cc +++ b/xla/hlo/transforms/simplifiers/algebraic_simplifier_test.cc @@ -9484,6 +9484,21 @@ TEST_F(AlgebraicSimplifierTest, ASSERT_FALSE(AlgebraicSimplifier(default_options_).Run(m.get()).value()); } +TEST_F(AlgebraicSimplifierTest, + DotToMultiplyRewriteWith_F32_F32_F32_Algorithm) { + constexpr char kModuleStr[] = R"( + HloModule test + ENTRY dot { + a = f32[128]{0} parameter(0) + b = f32[128]{0} parameter(1) + ROOT dot = f32[128,128]{1,0} dot(a, b), + algorithm=dot_f32_f32_f32 + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).value()); +} + TEST_F(AlgebraicSimplifierTest, RemainderOfIota) { const char* kModuleStr = R"( HloModule m @@ -10445,6 +10460,23 @@ TEST_F(AlgebraicSimplifierTest, ASSERT_FALSE(AlgebraicSimplifier(default_options_).Run(m.get()).value()); } +TEST_F(AlgebraicSimplifierTest, + DotStrengthReductionWith_F32_F32_F32_Algorithm) { + constexpr char kModuleStr[] = R"( + HloModule test + ENTRY dot { + a = f32[128,2]{1,0} parameter(0) + b = f32[2]{0} parameter(1) + ROOT dot = f32[128]{0} dot(a, b), + lhs_contracting_dims={1}, + rhs_contracting_dims={0}, + algorithm=dot_f32_f32_f32 + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).value()); +} + TEST_F(AlgebraicSimplifierTest, UnaryVariadicReduce) { const char* kModuleStr = R"( HloModule m