Skip to content

Commit

Permalink
#sdy Merge XLA CallInliner and ShardyCallInliner.
Browse files Browse the repository at this point in the history
Now that Shardy will now be fully integrated, we should delete the `ShardyCallInliner` and update `CallInliner` to look for what `ShardyCallInliner` checks for. We've had two bugs because of this thus far.

PiperOrigin-RevId: 679653544
  • Loading branch information
bartchr808 authored and Google-ML-Automation committed Sep 27, 2024
1 parent b7dbeb6 commit c850404
Show file tree
Hide file tree
Showing 6 changed files with 92 additions and 233 deletions.
9 changes: 8 additions & 1 deletion xla/service/call_inliner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ limitations under the License.
#include "absl/log/log.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/match.h"
#include "absl/strings/string_view.h"
#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h"
#include "xla/hlo/ir/hlo_instruction.h"
Expand All @@ -32,6 +33,7 @@ limitations under the License.
#include "xla/service/call_graph.h"
#include "xla/service/hlo_dce.h"
#include "xla/service/hlo_domain_isolator.h"
#include "xla/service/spmd/shardy/constants.h"
#include "xla/status_macros.h"
#include "xla/util.h"
#include "tsl/platform/errors.h"
Expand Down Expand Up @@ -159,7 +161,12 @@ CallInliner::Inline(HloInstruction* call) {
}

bool CallInliner::IsInlineableCallOp(HloInstruction* instruction) const {
return instruction->opcode() == HloOpcode::kCall &&
bool inline_under_shardy =
!(instruction->GetModule()->config().use_shardy_partitioner() &&
(absl::StrContains(instruction->to_apply()->name(), "shmap_body") ||
absl::StartsWith(instruction->to_apply()->name(),
sdy::kManualComputationBodyFuncName)));
return inline_under_shardy && instruction->opcode() == HloOpcode::kCall &&
!instruction->has_backend_config() &&
!instruction->parent()->IsAsyncComputation();
}
Expand Down
84 changes: 84 additions & 0 deletions xla/service/call_inliner_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -377,5 +377,89 @@ TEST_F(CallInlinerTest, InlineCompositeCall) {
EXPECT_TRUE((*inst)->frontend_attributes().map().empty());
}

TEST_F(CallInlinerTest, UseShardyMhloToHloShmapBodyNotInlined) {
const char* const hloString = R"(
HloModule jit_f, entry_computation_layout={(f32[8,8]{1,0})->f32[8,8]{1,0}}
%prefix_shmap_body_suffix.4 (Arg_0.5: f32[1,8]) -> f32[1,8] {
%Arg_0.5 = f32[1,8]{1,0} parameter(0)
ROOT %add.6 = f32[1,8]{1,0} add(f32[1,8]{1,0} %Arg_0.5, f32[1,8]{1,0} %Arg_0.5), metadata={source_file="-" source_line=11}
}
ENTRY %main.10 (Arg_0.1: f32[8,8]) -> f32[8,8] {
%Arg_0.1 = f32[8,8]{1,0} parameter(0)
%custom-call.2 = f32[8,8]{1,0} custom-call(f32[8,8]{1,0} %Arg_0.1), custom_call_target="Sharding", sharding={devices=[8,1]<=[8]}, metadata={source_file="-" source_line=3}
%custom-call.3 = f32[1,8]{1,0} custom-call(f32[8,8]{1,0} %custom-call.2), custom_call_target="SPMDFullToShardShape", sharding={manual}, metadata={source_file="-" source_line=4}
%call.7 = f32[1,8]{1,0} call(f32[1,8]{1,0} %custom-call.3), to_apply=%prefix_shmap_body_suffix.4
%custom-call.8 = f32[1,8]{1,0} custom-call(f32[1,8]{1,0} %call.7), custom_call_target="Sharding", sharding={manual}, metadata={source_file="-" source_line=6}
ROOT %custom-call.9 = f32[8,8]{1,0} custom-call(f32[1,8]{1,0} %custom-call.8), custom_call_target="SPMDShardToFullShape", sharding={devices=[8,1]<=[8]}, metadata={source_file="-" source_line=7}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hloString));
module->mutable_config().set_use_shardy_partitioner(true);
TF_ASSERT_OK_AND_ASSIGN(bool changed, ShardyCallInliner().Run(module.get()));
VLOG(1) << module->ToString();
// The single call in the module is not inlined.
EXPECT_FALSE(changed);

HloInstruction* call = FindInstruction(module.get(), xla::HloOpcode::kCall);
EXPECT_NE(call, nullptr);
EXPECT_TRUE(call->has_to_apply());
EXPECT_EQ(call->to_apply()->name(), "prefix_shmap_body_suffix.4");
}

// Don't inline when the name starts with "xla.sdy.manual_computation_body".
TEST_F(CallInlinerTest, UseShardManualComputationBodyNotInlined) {
const char* const hloString = R"(
HloModule jit_f, entry_computation_layout={(f32[8,8]{1,0})->f32[8,8]{1,0}}
%xla.sdy.manual_computation_body.4 (Arg_0.5: f32[1,8]) -> f32[1,8] {
%Arg_0.5 = f32[1,8]{1,0} parameter(0)
ROOT %add.6 = f32[1,8]{1,0} add(f32[1,8]{1,0} %Arg_0.5, f32[1,8]{1,0} %Arg_0.5), metadata={source_file="-" source_line=11}
}
ENTRY %main.10 (Arg_0.1: f32[8,8]) -> f32[8,8] {
%Arg_0.1 = f32[8,8]{1,0} parameter(0)
%custom-call.3 = f32[1,8]{1,0} custom-call(f32[8,8]{1,0} %Arg_0.1), custom_call_target="SPMDFullToShardShape", sharding={manual}, metadata={source_file="-" source_line=4}
%call.7 = f32[1,8]{1,0} call(f32[1,8]{1,0} %custom-call.3), to_apply=%xla.sdy.manual_computation_body.4
ROOT %custom-call.9 = f32[8,8]{1,0} custom-call(f32[1,8]{1,0} %call.7), custom_call_target="SPMDShardToFullShape", sharding={devices=[8,1]<=[8]}, metadata={source_file="-" source_line=7}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hloString));
module->mutable_config().set_use_shardy_partitioner(true);
TF_ASSERT_OK_AND_ASSIGN(bool changed, ShardyCallInliner().Run(module.get()));
VLOG(1) << module->ToString();
// The single call in the module is not inlined.
EXPECT_FALSE(changed);

HloInstruction* call = FindInstruction(module.get(), xla::HloOpcode::kCall);
EXPECT_NE(call, nullptr);
EXPECT_TRUE(call->has_to_apply());
EXPECT_EQ(call->to_apply()->name(), "xla.sdy.manual_computation_body.4");
}

// Inliner only checks if the name of the function has
// "xla.sdy.manual_computation_body" a prefix, not if it contains it.
TEST_F(CallInlinerTest, UseShardManualComputationBodyInlined) {
const char* const hloString = R"(
HloModule jit_f, entry_computation_layout={(f32[8,8]{1,0})->f32[8,8]{1,0}}
%prefix_xla.sdy.manual_computation_body.4 (Arg_0.5: f32[1,8]) -> f32[1,8] {
%Arg_0.5 = f32[1,8]{1,0} parameter(0)
ROOT %add.6 = f32[1,8]{1,0} add(f32[1,8]{1,0} %Arg_0.5, f32[1,8]{1,0} %Arg_0.5), metadata={source_file="-" source_line=11}
}
ENTRY %main.10 (Arg_0.1: f32[8,8]) -> f32[8,8] {
%Arg_0.1 = f32[8,8]{1,0} parameter(0)
%custom-call.3 = f32[1,8]{1,0} custom-call(f32[8,8]{1,0} %Arg_0.1), custom_call_target="SPMDFullToShardShape", sharding={manual}, metadata={source_file="-" source_line=4}
%call.7 = f32[1,8]{1,0} call(f32[1,8]{1,0} %custom-call.3), to_apply=%prefix_xla.sdy.manual_computation_body.4
ROOT %custom-call.9 = f32[8,8]{1,0} custom-call(f32[1,8]{1,0} %call.7), custom_call_target="SPMDShardToFullShape", sharding={devices=[8,1]<=[8]}, metadata={source_file="-" source_line=7}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hloString));
module->mutable_config().set_use_shardy_partitioner(true);
TF_ASSERT_OK_AND_ASSIGN(bool changed, ShardyCallInliner().Run(module.get()));
VLOG(1) << module->ToString();
// Will be inlined.
EXPECT_TRUE(changed);
}

} // namespace
} // namespace xla
25 changes: 0 additions & 25 deletions xla/service/spmd/shardy/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -23,31 +23,6 @@ package_group(
],
)

cc_library(
name = "shardy_call_inliner",
srcs = ["shardy_call_inliner.cc"],
hdrs = ["shardy_call_inliner.h"],
deps = [
"//xla/hlo/ir:hlo",
"//xla/service:call_inliner",
"//xla/service/spmd/shardy:constants",
"@com_google_absl//absl/strings",
],
)

xla_cc_test(
name = "shardy_call_inliner_test",
srcs = ["shardy_call_inliner_test.cc"],
deps = [
":shardy_call_inliner",
"//xla/hlo/ir:hlo",
"//xla/tests:hlo_test_base",
"@com_google_absl//absl/log",
"@com_google_googletest//:gtest_main",
"@tsl//tsl/platform:statusor",
],
)

cc_library(
name = "shardy_xla_pass",
srcs = ["shardy_xla_pass.cc"],
Expand Down
33 changes: 0 additions & 33 deletions xla/service/spmd/shardy/shardy_call_inliner.cc

This file was deleted.

59 changes: 0 additions & 59 deletions xla/service/spmd/shardy/shardy_call_inliner.h

This file was deleted.

115 changes: 0 additions & 115 deletions xla/service/spmd/shardy/shardy_call_inliner_test.cc

This file was deleted.

0 comments on commit c850404

Please sign in to comment.