// Patch overview (review notes; the patch text below is reproduced unchanged).
//
// NOTE(review): this text looks like a unified git diff whose line breaks were
// collapsed, so hunk line counts can no longer be checked against the layout.
// The PartitionBaseCase hunk header reads absl::StatusOr without template
// arguments; presumably they were lost in extraction - TODO confirm against
// the real dot_handler.cc.
//
// Files touched and what each hunk adds:
//  * xla/service/spmd/BUILD: adds the //xla:side_effect_util dependency to a
//    cc_library dependency list.
//  * xla/service/spmd/dot_handler.cc: includes xla/side_effect_util.h. In the
//    PartitionBaseCase hunk, should_skip_windowed_einsum is no longer
//    initialized to false: it starts out true when the frontend attribute map
//    of original_hlo contains the key kXlaCollectiveMatmulAttr with a value
//    equal to kXlaCollectiveMatmulNone. The existing
//    options.disable_ag_rewrite_for_multiple_consumers branch is now entered
//    only when the flag is still false.
//  * xla/service/spmd/spmd_partitioner_test.cc: adds
//    DisableWindowedEinsumWithUserAnnotation, which partitions a dot carrying
//    the frontend attribute _xla_collective_matmul=none across 2 devices with
//    threshold_for_windowed_einsum_mib set to 0, and asserts that no
//    instruction in the entry computation has opcode kWhile.
//  * xla/side_effect_util.cc and xla/side_effect_util.h: define and declare
//    kXlaCollectiveMatmulNone, whose value is the string none.
//
// NOTE(review): attrs.contains(key) followed by attrs.at(key) looks the key up
// twice; a single find() would do. Not a correctness problem.
// NOTE(review): how should_skip_windowed_einsum is consumed lies outside this
// patch; presumably it suppresses the windowed-einsum rewrite, as the test
// name suggests - verify in PartitionBaseCase.
diff --git a/xla/service/spmd/BUILD b/xla/service/spmd/BUILD index 705b126fee4ff..d96db3f920e3c 100644 --- a/xla/service/spmd/BUILD +++ b/xla/service/spmd/BUILD @@ -40,6 +40,7 @@ cc_library( "//xla:literal_util", "//xla:protobuf_util", "//xla:shape_util", + "//xla:side_effect_util", "//xla:status_macros", "//xla:types", "//xla:util", diff --git a/xla/service/spmd/dot_handler.cc b/xla/service/spmd/dot_handler.cc index fd74a10be12eb..2b7547b2ea35b 100644 --- a/xla/service/spmd/dot_handler.cc +++ b/xla/service/spmd/dot_handler.cc @@ -52,6 +52,7 @@ limitations under the License. #include "xla/service/spmd/spmd_partitioner_util.h" #include "xla/shape.h" #include "xla/shape_util.h" +#include "xla/side_effect_util.h" #include "xla/status_macros.h" #include "xla/util.h" #include "xla/window_util.h" @@ -1905,8 +1906,12 @@ absl::StatusOr PartitionBaseCase( hlo.hlo()->opcode() == HloOpcode::kBitcast || hlo.hlo()->opcode() == HloOpcode::kTranspose; }; - bool should_skip_windowed_einsum = false; - if (options.disable_ag_rewrite_for_multiple_consumers) { + const auto& attrs = original_hlo->frontend_attributes().map(); + bool should_skip_windowed_einsum = + attrs.contains(kXlaCollectiveMatmulAttr) && + attrs.at(kXlaCollectiveMatmulAttr) == kXlaCollectiveMatmulNone; + if (!should_skip_windowed_einsum && + options.disable_ag_rewrite_for_multiple_consumers) { auto lhs_operand = has_reshape_operand(lhs) ? 
lhs.hlo()->operand(0) : lhs.hlo(); auto rhs_operand = diff --git a/xla/service/spmd/spmd_partitioner_test.cc b/xla/service/spmd/spmd_partitioner_test.cc index 07a317ca0d662..1288357f5ad37 100644 --- a/xla/service/spmd/spmd_partitioner_test.cc +++ b/xla/service/spmd/spmd_partitioner_test.cc @@ -4975,6 +4975,30 @@ ENTRY entry { op::Shape("f32[16,256,1024]"))); } +TEST_P(SpmdPartitioningTest, DisableWindowedEinsumWithUserAnnotation) { + absl::string_view hlo_string = R"( +HloModule module + +ENTRY entry { + %p0 = f32[2048,2,3264]{2,1,0} parameter(0), sharding={devices=[1,1,2]0,1} + %p1 = f32[2,3264,2176]{2,1,0} parameter(1), sharding={devices=[2,1,1]0,1} + ROOT %dot.224 = f32[2048,2176]{1,0} dot(f32[2048,2,3264]{2,1,0} %p0, f32[2,3264,2176]{2,1,0} %p1), lhs_contracting_dims={1,2}, rhs_contracting_dims={0,1}, sharding={devices=[1,2]0,1}, frontend_attributes={_xla_collective_matmul="none"} +})"; + + TF_ASSERT_OK_AND_ASSIGN( + auto module, + PartitionComputation(hlo_string, /*num_devices=*/2, + /*conv_halo_exchange_always_on_lhs=*/true, + /*choose_faster_windowed_einsum=*/false, + /*unroll_windowed_einsum=*/false, + /*bidirectional_windowed_einsum=*/false, + /*threshold_for_windowed_einsum_mib=*/0)); + ASSERT_FALSE(absl::c_any_of(module->entry_computation()->instructions(), + [](const HloInstruction* inst) { + return inst->opcode() == HloOpcode::kWhile; + })); +} + TEST_P(SpmdPartitioningTest, EinsumBatchPartitioned) { absl::string_view hlo_string = R"( HloModule module diff --git a/xla/side_effect_util.cc b/xla/side_effect_util.cc index f874bd4a5c6f3..602d76b66a488 100644 --- a/xla/side_effect_util.cc +++ b/xla/side_effect_util.cc @@ -69,6 +69,8 @@ const char kXlaCollectiveMatmulRhsAg[] = "rhs_ag"; const char kXlaCollectiveMatmulRs[] = "rs"; +const char kXlaCollectiveMatmulNone[] = "none"; + const char kXlaMultiRecvCountAttr[] = "_xla_multi_recv_count"; } // namespace xla diff --git a/xla/side_effect_util.h b/xla/side_effect_util.h index 13a74a46d5a00..281a007b4cd8b 
100644 --- a/xla/side_effect_util.h +++ b/xla/side_effect_util.h @@ -77,6 +77,7 @@ extern const char kXlaCollectiveMatmulAttr[]; extern const char kXlaCollectiveMatmulLhsAg[]; extern const char kXlaCollectiveMatmulRhsAg[]; extern const char kXlaCollectiveMatmulRs[]; +extern const char kXlaCollectiveMatmulNone[]; // XLA frontend attribute for specifying the number of sends this recv should // match.