diff --git a/xla/service/call_inliner.cc b/xla/service/call_inliner.cc index a879e560c1cd17..cb991b45c0dd7e 100644 --- a/xla/service/call_inliner.cc +++ b/xla/service/call_inliner.cc @@ -24,6 +24,7 @@ limitations under the License. #include "absl/log/log.h" #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/match.h" #include "absl/strings/string_view.h" #include "xla/hlo/ir/dfs_hlo_visitor_with_default.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -32,6 +33,7 @@ limitations under the License. #include "xla/service/call_graph.h" #include "xla/service/hlo_dce.h" #include "xla/service/hlo_domain_isolator.h" +#include "xla/service/spmd/shardy/constants.h" #include "xla/status_macros.h" #include "xla/util.h" #include "tsl/platform/errors.h" @@ -159,7 +161,12 @@ CallInliner::Inline(HloInstruction* call) { } bool CallInliner::IsInlineableCallOp(HloInstruction* instruction) const { - return instruction->opcode() == HloOpcode::kCall && + bool inline_under_shardy = + !(instruction->GetModule()->config().use_shardy_partitioner() && + (absl::StrContains(instruction->to_apply()->name(), "shmap_body") || + absl::StartsWith(instruction->to_apply()->name(), + sdy::kManualComputationBodyFuncName))); + return inline_under_shardy && instruction->opcode() == HloOpcode::kCall && !instruction->has_backend_config() && !instruction->parent()->IsAsyncComputation(); } diff --git a/xla/service/call_inliner_test.cc b/xla/service/call_inliner_test.cc index ad6ee73eb14e8a..ca829edd9ea648 100644 --- a/xla/service/call_inliner_test.cc +++ b/xla/service/call_inliner_test.cc @@ -377,5 +377,89 @@ TEST_F(CallInlinerTest, InlineCompositeCall) { EXPECT_TRUE((*inst)->frontend_attributes().map().empty()); } +TEST_F(CallInlinerTest, UseShardyMhloToHloShmapBodyNotInlined) { + const char* const hloString = R"( + HloModule jit_f, entry_computation_layout={(f32[8,8]{1,0})->f32[8,8]{1,0}} + + %prefix_shmap_body_suffix.4 (Arg_0.5: f32[1,8]) -> f32[1,8] { + %Arg_0.5 = f32[1,8]{1,0} parameter(0) + ROOT %add.6 = f32[1,8]{1,0} add(f32[1,8]{1,0} %Arg_0.5, f32[1,8]{1,0} %Arg_0.5), metadata={source_file="-" source_line=11} + } + + ENTRY %main.10 (Arg_0.1: f32[8,8]) -> f32[8,8] { + %Arg_0.1 = f32[8,8]{1,0} parameter(0) + %custom-call.2 = f32[8,8]{1,0} custom-call(f32[8,8]{1,0} %Arg_0.1), custom_call_target="Sharding", sharding={devices=[8,1]<=[8]}, metadata={source_file="-" source_line=3} + %custom-call.3 = f32[1,8]{1,0} custom-call(f32[8,8]{1,0} %custom-call.2), custom_call_target="SPMDFullToShardShape", sharding={manual}, metadata={source_file="-" source_line=4} + %call.7 = f32[1,8]{1,0} call(f32[1,8]{1,0} %custom-call.3), to_apply=%prefix_shmap_body_suffix.4 + %custom-call.8 = f32[1,8]{1,0} custom-call(f32[1,8]{1,0} %call.7), custom_call_target="Sharding", sharding={manual}, metadata={source_file="-" source_line=6} + ROOT %custom-call.9 = f32[8,8]{1,0} custom-call(f32[1,8]{1,0} %custom-call.8), custom_call_target="SPMDShardToFullShape", sharding={devices=[8,1]<=[8]}, metadata={source_file="-" source_line=7} + })"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hloString)); + module->mutable_config().set_use_shardy_partitioner(true); + TF_ASSERT_OK_AND_ASSIGN(bool changed, ShardyCallInliner().Run(module.get())); + VLOG(1) << module->ToString(); + // The single call in the module is not inlined. + EXPECT_FALSE(changed); + + HloInstruction* call = FindInstruction(module.get(), xla::HloOpcode::kCall); + EXPECT_NE(call, nullptr); + EXPECT_TRUE(call->has_to_apply()); + EXPECT_EQ(call->to_apply()->name(), "prefix_shmap_body_suffix.4"); +} + +// Don't inline when the name starts with "xla.sdy.manual_computation_body". +TEST_F(CallInlinerTest, UseShardManualComputationBodyNotInlined) { + const char* const hloString = R"( + HloModule jit_f, entry_computation_layout={(f32[8,8]{1,0})->f32[8,8]{1,0}} + + %xla.sdy.manual_computation_body.4 (Arg_0.5: f32[1,8]) -> f32[1,8] { + %Arg_0.5 = f32[1,8]{1,0} parameter(0) + ROOT %add.6 = f32[1,8]{1,0} add(f32[1,8]{1,0} %Arg_0.5, f32[1,8]{1,0} %Arg_0.5), metadata={source_file="-" source_line=11} + } + + ENTRY %main.10 (Arg_0.1: f32[8,8]) -> f32[8,8] { + %Arg_0.1 = f32[8,8]{1,0} parameter(0) + %custom-call.3 = f32[1,8]{1,0} custom-call(f32[8,8]{1,0} %Arg_0.1), custom_call_target="SPMDFullToShardShape", sharding={manual}, metadata={source_file="-" source_line=4} + %call.7 = f32[1,8]{1,0} call(f32[1,8]{1,0} %custom-call.3), to_apply=%xla.sdy.manual_computation_body.4 + ROOT %custom-call.9 = f32[8,8]{1,0} custom-call(f32[1,8]{1,0} %call.7), custom_call_target="SPMDShardToFullShape", sharding={devices=[8,1]<=[8]}, metadata={source_file="-" source_line=7} + })"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hloString)); + module->mutable_config().set_use_shardy_partitioner(true); + TF_ASSERT_OK_AND_ASSIGN(bool changed, ShardyCallInliner().Run(module.get())); + VLOG(1) << module->ToString(); + // The single call in the module is not inlined. + EXPECT_FALSE(changed); + + HloInstruction* call = FindInstruction(module.get(), xla::HloOpcode::kCall); + EXPECT_NE(call, nullptr); + EXPECT_TRUE(call->has_to_apply()); + EXPECT_EQ(call->to_apply()->name(), "xla.sdy.manual_computation_body.4"); +} + +// Inliner only checks if the name of the function has +// "xla.sdy.manual_computation_body" a prefix, not if it contains it. +TEST_F(CallInlinerTest, UseShardManualComputationBodyInlined) { + const char* const hloString = R"( + HloModule jit_f, entry_computation_layout={(f32[8,8]{1,0})->f32[8,8]{1,0}} + + %prefix_xla.sdy.manual_computation_body.4 (Arg_0.5: f32[1,8]) -> f32[1,8] { + %Arg_0.5 = f32[1,8]{1,0} parameter(0) + ROOT %add.6 = f32[1,8]{1,0} add(f32[1,8]{1,0} %Arg_0.5, f32[1,8]{1,0} %Arg_0.5), metadata={source_file="-" source_line=11} + } + + ENTRY %main.10 (Arg_0.1: f32[8,8]) -> f32[8,8] { + %Arg_0.1 = f32[8,8]{1,0} parameter(0) + %custom-call.3 = f32[1,8]{1,0} custom-call(f32[8,8]{1,0} %Arg_0.1), custom_call_target="SPMDFullToShardShape", sharding={manual}, metadata={source_file="-" source_line=4} + %call.7 = f32[1,8]{1,0} call(f32[1,8]{1,0} %custom-call.3), to_apply=%prefix_xla.sdy.manual_computation_body.4 + ROOT %custom-call.9 = f32[8,8]{1,0} custom-call(f32[1,8]{1,0} %call.7), custom_call_target="SPMDShardToFullShape", sharding={devices=[8,1]<=[8]}, metadata={source_file="-" source_line=7} + })"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hloString)); + module->mutable_config().set_use_shardy_partitioner(true); + TF_ASSERT_OK_AND_ASSIGN(bool changed, ShardyCallInliner().Run(module.get())); + VLOG(1) << module->ToString(); + // Will be inlined. + EXPECT_TRUE(changed); +} + } // namespace } // namespace xla diff --git a/xla/service/spmd/shardy/BUILD b/xla/service/spmd/shardy/BUILD index 385e284d739f22..5c98e6e2cc3bb5 100644 --- a/xla/service/spmd/shardy/BUILD +++ b/xla/service/spmd/shardy/BUILD @@ -23,31 +23,6 @@ package_group( ], ) -cc_library( - name = "shardy_call_inliner", - srcs = ["shardy_call_inliner.cc"], - hdrs = ["shardy_call_inliner.h"], - deps = [ - "//xla/hlo/ir:hlo", - "//xla/service:call_inliner", - "//xla/service/spmd/shardy:constants", - "@com_google_absl//absl/strings", - ], -) - -xla_cc_test( - name = "shardy_call_inliner_test", - srcs = ["shardy_call_inliner_test.cc"], - deps = [ - ":shardy_call_inliner", - "//xla/hlo/ir:hlo", - "//xla/tests:hlo_test_base", - "@com_google_absl//absl/log", - "@com_google_googletest//:gtest_main", - "@tsl//tsl/platform:statusor", - ], -) - cc_library( name = "shardy_xla_pass", srcs = ["shardy_xla_pass.cc"], diff --git a/xla/service/spmd/shardy/shardy_call_inliner.cc b/xla/service/spmd/shardy/shardy_call_inliner.cc deleted file mode 100644 index 2de735c98ecbbc..00000000000000 --- a/xla/service/spmd/shardy/shardy_call_inliner.cc +++ /dev/null @@ -1,33 +0,0 @@ -/* Copyright 2024 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/service/spmd/shardy/shardy_call_inliner.h" - -#include "absl/strings/match.h" -#include "xla/hlo/ir/hlo_instruction.h" -#include "xla/service/call_inliner.h" -#include "xla/service/spmd/shardy/constants.h" - -namespace xla { - -bool ShardyCallInliner::IsInlineableCallOp(HloInstruction* instruction) const { - return CallInliner::IsInlineableCallOp(instruction) && - !(instruction->GetModule()->config().use_shardy_partitioner() && - (absl::StrContains(instruction->to_apply()->name(), "shmap_body") || - absl::StartsWith(instruction->to_apply()->name(), - sdy::kManualComputationBodyFuncName))); -} - -} // namespace xla diff --git a/xla/service/spmd/shardy/shardy_call_inliner.h b/xla/service/spmd/shardy/shardy_call_inliner.h deleted file mode 100644 index 666e168322b5ab..00000000000000 --- a/xla/service/spmd/shardy/shardy_call_inliner.h +++ /dev/null @@ -1,59 +0,0 @@ -/* Copyright 2024 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_SPMD_SHARDY_SHARDY_CALL_INLINER_H_ -#define XLA_SERVICE_SPMD_SHARDY_SHARDY_CALL_INLINER_H_ - -#include "xla/hlo/ir/hlo_instruction.h" -#include "xla/service/call_inliner.h" - -namespace xla { - -// The same as CallInliner, except as part of -// go/jax-shmap -> `sdy.ManualComputationOp` importing, we require the pattern -// in MHLO: -// ``` -// %shard_arg0_0 = custom_call @Sharding(%0) -// %shard_arg0_1 = custom_call @SPMDFullToShardShape(%shard_arg0_0) -// ... -// %shard_argN_0 = custom_call @Sharding(%N) -// %shard_argN_1 = custom_call @SPMDFullToShardShape(%shard_argN_0) -// -// %shard_result0, ..., %shard_resultN = func.call @shmap_body(%shard_arg0_1, -// ..., -// %shard_argN_1) -// -// %shard_result0_0 = custom_call @Sharding(%shard_result0) -// %shard_result0_1 = custom_call @SPMDShardToFullShape(%shard_result0_0) -// ... -// %shard_resultN_0 = custom_call @Sharding(%shard_resultN) -// %shard_resultN_1 = custom_call @SPMDShardToFullShape(%shard_resultN_0) -// ``` -// We specifically match on the `func.call @shmap_body` since we want to inline -// the body of that function into the `ManualComputationOp` body. So this makes -// sure we inline all functions except for the shmap_body's when using -// Shardy. When Shardy is disabled, then we have the same behavior as -// CallInliner. -class ShardyCallInliner : public CallInliner { - public: - using CallInliner::CallInliner; - absl::string_view name() const override { return "shardy-call-inliner"; } - - bool IsInlineableCallOp(HloInstruction* instruction) const override; -}; - -} // namespace xla - -#endif // XLA_SERVICE_SPMD_SHARDY_SHARDY_CALL_INLINER_H_ diff --git a/xla/service/spmd/shardy/shardy_call_inliner_test.cc b/xla/service/spmd/shardy/shardy_call_inliner_test.cc deleted file mode 100644 index b2055e59d75c34..00000000000000 --- a/xla/service/spmd/shardy/shardy_call_inliner_test.cc +++ /dev/null @@ -1,115 +0,0 @@ -/* Copyright 2024 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/service/spmd/shardy/shardy_call_inliner.h" - -#include -#include "absl/log/log.h" -#include "xla/hlo/ir/hlo_instruction.h" -#include "xla/hlo/ir/hlo_opcode.h" -#include "xla/tests/hlo_test_base.h" -#include "tsl/platform/statusor.h" - -namespace xla { -namespace sdy { - -using ShardyCallInlinerTest = xla::HloTestBase; - -TEST_F(ShardyCallInlinerTest, MhloToHloShmapBodyNotInlined) { - const char* const hloString = R"( - HloModule jit_f, entry_computation_layout={(f32[8,8]{1,0})->f32[8,8]{1,0}} - - %prefix_shmap_body_suffix.4 (Arg_0.5: f32[1,8]) -> f32[1,8] { - %Arg_0.5 = f32[1,8]{1,0} parameter(0) - ROOT %add.6 = f32[1,8]{1,0} add(f32[1,8]{1,0} %Arg_0.5, f32[1,8]{1,0} %Arg_0.5), metadata={source_file="-" source_line=11} - } - - ENTRY %main.10 (Arg_0.1: f32[8,8]) -> f32[8,8] { - %Arg_0.1 = f32[8,8]{1,0} parameter(0) - %custom-call.2 = f32[8,8]{1,0} custom-call(f32[8,8]{1,0} %Arg_0.1), custom_call_target="Sharding", sharding={devices=[8,1]<=[8]}, metadata={source_file="-" source_line=3} - %custom-call.3 = f32[1,8]{1,0} custom-call(f32[8,8]{1,0} %custom-call.2), custom_call_target="SPMDFullToShardShape", sharding={manual}, metadata={source_file="-" source_line=4} - %call.7 = f32[1,8]{1,0} call(f32[1,8]{1,0} %custom-call.3), to_apply=%prefix_shmap_body_suffix.4 - %custom-call.8 = f32[1,8]{1,0} custom-call(f32[1,8]{1,0} %call.7), custom_call_target="Sharding", sharding={manual}, metadata={source_file="-" source_line=6} - ROOT %custom-call.9 = f32[8,8]{1,0} custom-call(f32[1,8]{1,0} %custom-call.8), custom_call_target="SPMDShardToFullShape", sharding={devices=[8,1]<=[8]}, metadata={source_file="-" source_line=7} - })"; - TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hloString)); - module->mutable_config().set_use_shardy_partitioner(true); - TF_ASSERT_OK_AND_ASSIGN(bool changed, ShardyCallInliner().Run(module.get())); - VLOG(1) << module->ToString(); - // The single call in the module is not inlined. - EXPECT_FALSE(changed); - - HloInstruction* call = FindInstruction(module.get(), xla::HloOpcode::kCall); - EXPECT_NE(call, nullptr); - EXPECT_TRUE(call->has_to_apply()); - EXPECT_EQ(call->to_apply()->name(), "prefix_shmap_body_suffix.4"); -} - -// Don't inline when the name starts with "xla.sdy.manual_computation_body". -TEST_F(ShardyCallInlinerTest, ManualComputationBodyNotInlined) { - const char* const hloString = R"( - HloModule jit_f, entry_computation_layout={(f32[8,8]{1,0})->f32[8,8]{1,0}} - - %xla.sdy.manual_computation_body.4 (Arg_0.5: f32[1,8]) -> f32[1,8] { - %Arg_0.5 = f32[1,8]{1,0} parameter(0) - ROOT %add.6 = f32[1,8]{1,0} add(f32[1,8]{1,0} %Arg_0.5, f32[1,8]{1,0} %Arg_0.5), metadata={source_file="-" source_line=11} - } - - ENTRY %main.10 (Arg_0.1: f32[8,8]) -> f32[8,8] { - %Arg_0.1 = f32[8,8]{1,0} parameter(0) - %custom-call.3 = f32[1,8]{1,0} custom-call(f32[8,8]{1,0} %Arg_0.1), custom_call_target="SPMDFullToShardShape", sharding={manual}, metadata={source_file="-" source_line=4} - %call.7 = f32[1,8]{1,0} call(f32[1,8]{1,0} %custom-call.3), to_apply=%xla.sdy.manual_computation_body.4 - ROOT %custom-call.9 = f32[8,8]{1,0} custom-call(f32[1,8]{1,0} %call.7), custom_call_target="SPMDShardToFullShape", sharding={devices=[8,1]<=[8]}, metadata={source_file="-" source_line=7} - })"; - TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hloString)); - module->mutable_config().set_use_shardy_partitioner(true); - TF_ASSERT_OK_AND_ASSIGN(bool changed, ShardyCallInliner().Run(module.get())); - VLOG(1) << module->ToString(); - // The single call in the module is not inlined. - EXPECT_FALSE(changed); - - HloInstruction* call = FindInstruction(module.get(), xla::HloOpcode::kCall); - EXPECT_NE(call, nullptr); - EXPECT_TRUE(call->has_to_apply()); - EXPECT_EQ(call->to_apply()->name(), "xla.sdy.manual_computation_body.4"); -} - -// Inliner only checks if the name of the function has -// "xla.sdy.manual_computation_body" a prefix, not if it contains it. -TEST_F(ShardyCallInlinerTest, ManualComputationBodyInlined) { - const char* const hloString = R"( - HloModule jit_f, entry_computation_layout={(f32[8,8]{1,0})->f32[8,8]{1,0}} - - %prefix_xla.sdy.manual_computation_body.4 (Arg_0.5: f32[1,8]) -> f32[1,8] { - %Arg_0.5 = f32[1,8]{1,0} parameter(0) - ROOT %add.6 = f32[1,8]{1,0} add(f32[1,8]{1,0} %Arg_0.5, f32[1,8]{1,0} %Arg_0.5), metadata={source_file="-" source_line=11} - } - - ENTRY %main.10 (Arg_0.1: f32[8,8]) -> f32[8,8] { - %Arg_0.1 = f32[8,8]{1,0} parameter(0) - %custom-call.3 = f32[1,8]{1,0} custom-call(f32[8,8]{1,0} %Arg_0.1), custom_call_target="SPMDFullToShardShape", sharding={manual}, metadata={source_file="-" source_line=4} - %call.7 = f32[1,8]{1,0} call(f32[1,8]{1,0} %custom-call.3), to_apply=%prefix_xla.sdy.manual_computation_body.4 - ROOT %custom-call.9 = f32[8,8]{1,0} custom-call(f32[1,8]{1,0} %call.7), custom_call_target="SPMDShardToFullShape", sharding={devices=[8,1]<=[8]}, metadata={source_file="-" source_line=7} - })"; - TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hloString)); - module->mutable_config().set_use_shardy_partitioner(true); - TF_ASSERT_OK_AND_ASSIGN(bool changed, ShardyCallInliner().Run(module.get())); - VLOG(1) << module->ToString(); - // Will be inlined. - EXPECT_TRUE(changed); -} - -} // namespace sdy -} // namespace xla