From 011c02f61f781a4febb8bf93ccbc6e593d46376d Mon Sep 17 00:00:00 2001 From: Toli Yevtushenko Date: Thu, 26 Sep 2024 15:08:23 -0700 Subject: [PATCH] Add a helper method to HloTestBase to run a pass on a parameterized HLO string. This is a common pattern in HLO transformation tests, and it's useful to have a helper method to reduce boilerplate. This CL also updates all_reduce_folder_test.cc to use the new helper method. PiperOrigin-RevId: 679309211 --- xla/service/BUILD | 2 +- xla/service/all_reduce_folder_test.cc | 284 +++++++++++--------------- xla/tests/BUILD | 9 + xla/tests/hlo_test_base.cc | 28 ++- xla/tests/hlo_test_base.h | 21 ++ 5 files changed, 174 insertions(+), 170 deletions(-) diff --git a/xla/service/BUILD b/xla/service/BUILD index a122a232e032d..ff449c30544fd 100644 --- a/xla/service/BUILD +++ b/xla/service/BUILD @@ -276,8 +276,8 @@ xla_cc_test( "//xla/hlo/utils:hlo_matchers", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:string_view", "@tsl//tsl/platform:statusor", ], diff --git a/xla/service/all_reduce_folder_test.cc b/xla/service/all_reduce_folder_test.cc index e984d089adb19..f23d1f7bdf097 100644 --- a/xla/service/all_reduce_folder_test.cc +++ b/xla/service/all_reduce_folder_test.cc @@ -16,12 +16,10 @@ limitations under the License. #include "xla/service/all_reduce_folder.h" #include -#include +#include #include -#include #include "absl/algorithm/container.h" -#include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" @@ -29,220 +27,180 @@ limitations under the License. #include "xla/hlo/utils/hlo_matchers.h" #include "xla/test.h" #include "xla/tests/hlo_test_base.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/statusor.h" namespace xla { namespace { -namespace m = xla::testing::opcode_matchers; +namespace matcher = xla::testing::opcode_matchers; using ::testing::HasSubstr; -class AllReduceFolderTest : public HloTestBase { - public: - absl::StatusOr> RunPass( - absl::string_view hlo_module, bool expect_change) { - TF_ASSIGN_OR_RETURN(auto module, ParseAndReturnVerifiedModule(hlo_module)); - auto changed = AllReduceFolder().Run(module.get()); - if (!changed.ok()) { - return changed.status(); - } - EXPECT_EQ(changed.value(), expect_change); - return absl::StatusOr>(std::move(module)); - } +class AllReduceFolderTest : public HloTestBase {}; - size_t AllReduceCount(std::unique_ptr &module) { - return absl::c_count_if(module->entry_computation()->instructions(), - HloPredicateIsOp); - } -}; +const char *k2AllReduce = R"( + HloModule m -TEST_F(AllReduceFolderTest, Simple) { - absl::string_view hlo_string = R"( -HloModule m + sum { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT add.2 = f32[] add(a, b) + } -sum { - a = f32[] parameter(0) - b = f32[] parameter(1) - ROOT add.2 = f32[] add(a, b) -} + ENTRY main { + p0 = f32[8] parameter(0) + ar0 = f32[8] all-reduce(p0), replica_groups=$group_0, to_apply=sum + ROOT ar1 = f32[8] all-reduce(ar0), replica_groups=$group_1, to_apply=sum + } + )"; -ENTRY main { - p0 = f32[8] parameter(0) - ar0 = f32[8] all-reduce(p0), replica_groups={{0,1},{2,3}}, to_apply=sum - ROOT ar1 = f32[8] all-reduce(ar0), replica_groups={{0,2},{1,3}}, to_apply=sum +size_t AllReduceCount(HloModule *module) { + return absl::c_count_if(module->entry_computation()->instructions(), + HloPredicateIsOp); } -)"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - RunPass(hlo_string, /*expect_change=*/true)); + +void ExpectOneAllReduce(HloModule *module, + absl::string_view target_replica_groups) { EXPECT_EQ(AllReduceCount(module), 1); HloInstruction *root = module->entry_computation()->root_instruction(); - EXPECT_THAT(root, m::AllReduce(m::Parameter(0))); - EXPECT_THAT(root->ToString(), HasSubstr("replica_groups={{0,1,2,3}}")); + EXPECT_THAT(root, matcher::AllReduce(matcher::Parameter(0))); + EXPECT_THAT(root->ToString(), HasSubstr(target_replica_groups)); } -// Same as Simple, but groups for the 2 all-reduce's are swapped. -TEST_F(AllReduceFolderTest, SimpleSwap) { - absl::string_view hlo_string = R"( -HloModule m - -sum { - a = f32[] parameter(0) - b = f32[] parameter(1) - ROOT add.2 = f32[] add(a, b) +TEST_F(AllReduceFolderTest, Simple) { + TF_ASSERT_OK_AND_ASSIGN( + auto module, RunAndCheckHloRewrite(k2AllReduce, AllReduceFolder(), true, + {{"$group_0", "{{0,1},{2,3}}"}, + {"$group_1", "{{0,2},{1,3}}"}})); + ExpectOneAllReduce(module.get(), "replica_groups={{0,1,2,3}}"); } -ENTRY main { - p0 = f32[8] parameter(0) - ar0 = f32[8] all-reduce(p0), replica_groups={{0,2},{1,3}}, to_apply=sum - ROOT ar1 = f32[8] all-reduce(ar0), replica_groups={{0,1},{2,3}}, to_apply=sum -} -)"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - RunPass(hlo_string, /*expect_change=*/true)); - EXPECT_EQ(AllReduceCount(module), 1); - HloInstruction *root = module->entry_computation()->root_instruction(); - EXPECT_THAT(root, m::AllReduce(m::Parameter(0))); - EXPECT_THAT(root->ToString(), HasSubstr("replica_groups={{0,1,2,3}}")); +// Same as Simple, but groups for the 2 all-reduce's are swapped. +TEST_F(AllReduceFolderTest, SimpleSwap) { + TF_ASSERT_OK_AND_ASSIGN( + auto module, RunAndCheckHloRewrite(k2AllReduce, AllReduceFolder(), true, + {{"$group_1", "{{0,1},{2,3}}"}, + {"$group_0", "{{0,2},{1,3}}"}})); + ExpectOneAllReduce(module.get(), "replica_groups={{0,1,2,3}}"); } -TEST_F(AllReduceFolderTest, EmptyReplicaGroups) { - absl::string_view hlo_string = R"( -HloModule m - -sum { - a = f32[] parameter(0) - b = f32[] parameter(1) - ROOT add.2 = f32[] add(a, b) +TEST_F(AllReduceFolderTest, BothEmptyReplicaGroups_NotTransformed) { + TF_ASSERT_OK(RunAndCheckHloRewrite(k2AllReduce, AllReduceFolder(), false, + {{"$group_0", "{}"}, {"$group_1", "{}"}})); } -ENTRY main { - p0 = f32[8] parameter(0) - ar0 = f32[8] all-reduce(p0), replica_groups={}, to_apply=sum - ROOT ar1 = f32[8] all-reduce(ar0), replica_groups={{0,2},{1,3}}, to_apply=sum -} -)"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - RunPass(hlo_string, /*expect_change=*/false)); +TEST_F(AllReduceFolderTest, EmptyReplicaGroups_NotTransformed) { + TF_ASSERT_OK(RunAndCheckHloRewrite( + k2AllReduce, AllReduceFolder(), false, + {{"$group_0", "{}"}, {"$group_1", "{{0,2},{1,3}}"}})); } -TEST_F(AllReduceFolderTest, MismatchOtherProperties0) { +TEST_F(AllReduceFolderTest, MismatchOtherProperties0_NotTransformed) { absl::string_view hlo_string = R"( -HloModule m + HloModule m -sum { - a = f32[] parameter(0) - b = f32[] parameter(1) - ROOT add.2 = f32[] add(a, b) -} + sum { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT add.2 = f32[] add(a, b) + } -ENTRY main { - p0 = f32[8] parameter(0) - ar0 = f32[8] all-reduce(p0), replica_groups={{0,1},{2,3}}, channel_id=1, to_apply=sum - ROOT ar1 = f32[8] all-reduce(ar0), replica_groups={{0,2},{1,3}}, to_apply=sum -} -)"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - RunPass(hlo_string, /*expect_change=*/false)); + ENTRY main { + p0 = f32[8] parameter(0) + ar0 = f32[8] all-reduce(p0), replica_groups={{0,1},{2,3}}, channel_id=1, to_apply=sum + ROOT ar1 = f32[8] all-reduce(ar0), replica_groups={{0,2},{1,3}}, to_apply=sum + } + )"; + TF_ASSERT_OK(RunAndCheckHloRewrite(hlo_string, AllReduceFolder(), false)); } -TEST_F(AllReduceFolderTest, MismatchOtherProperties1) { +TEST_F(AllReduceFolderTest, MismatchOtherProperties1_NotTransformed) { absl::string_view hlo_string = R"( -HloModule m + HloModule m -sum { - a = f32[] parameter(0) - b = f32[] parameter(1) - ROOT add.2 = f32[] add(a, b) -} + sum { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT add.2 = f32[] add(a, b) + } -mul { - a = f32[] parameter(0) - b = f32[] parameter(1) - ROOT mul = f32[] multiply(a, b) -} + mul { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT mul = f32[] multiply(a, b) + } -ENTRY main { - p0 = f32[8] parameter(0) - ar0 = f32[8] all-reduce(p0), replica_groups={{0,1},{2,3}}, to_apply=sum - ROOT ar1 = f32[8] all-reduce(ar0), replica_groups={{0,2},{1,3}}, to_apply=mul -} -)"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - RunPass(hlo_string, /*expect_change=*/false)); + ENTRY main { + p0 = f32[8] parameter(0) + ar0 = f32[8] all-reduce(p0), replica_groups={{0,1},{2,3}}, to_apply=sum + ROOT ar1 = f32[8] all-reduce(ar0), replica_groups={{0,2},{1,3}}, to_apply=mul + } + )"; + TF_ASSERT_OK(RunAndCheckHloRewrite(hlo_string, AllReduceFolder(), false)); } -TEST_F(AllReduceFolderTest, NotFoldable) { +TEST_F(AllReduceFolderTest, NotFoldable_NotTransformed) { absl::string_view hlo_string = R"( -HloModule m + HloModule m -sum { - a = f32[] parameter(0) - b = f32[] parameter(1) - ROOT add.2 = f32[] add(a, b) -} + sum { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT add.2 = f32[] add(a, b) + } -ENTRY main { - p0 = f32[8] parameter(0) - ar0 = f32[8] all-reduce(p0), replica_groups={{0,1},{2,3}}, to_apply=sum - ROOT ar1 = f32[8] all-reduce(ar0), replica_groups={{0,1},{2,3}}, to_apply=sum -} -)"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - RunPass(hlo_string, /*expect_change=*/false)); + ENTRY main { + p0 = f32[8] parameter(0) + ar0 = f32[8] all-reduce(p0), replica_groups={{0,1},{2,3}}, to_apply=sum + ROOT ar1 = f32[8] all-reduce(ar0), replica_groups={{0,1},{2,3}}, to_apply=sum + } + )"; + TF_ASSERT_OK(RunAndCheckHloRewrite(hlo_string, AllReduceFolder(), false)); } TEST_F(AllReduceFolderTest, Foldable0) { absl::string_view hlo_string = R"( -HloModule m + HloModule m -sum { - a = f32[] parameter(0) - b = f32[] parameter(1) - ROOT add.2 = f32[] add(a, b) -} + sum { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT add.2 = f32[] add(a, b) + } -ENTRY main { - p0 = f32[8] parameter(0) - ar0 = f32[8] all-reduce(p0), replica_groups={{0,4},{1,5},{2,3},{6,7}}, to_apply=sum - ROOT ar1 = f32[8] all-reduce(ar0), replica_groups={{0,5},{4,1},{2,7},{3,6}}, to_apply=sum -} -)"; + ENTRY main { + p0 = f32[8] parameter(0) + ar0 = f32[8] all-reduce(p0), replica_groups={{0,4},{1,5},{2,3},{6,7}}, to_apply=sum + ROOT ar1 = f32[8] all-reduce(ar0), replica_groups={{0,5},{4,1},{2,7},{3,6}}, to_apply=sum + } + )"; TF_ASSERT_OK_AND_ASSIGN(auto module, - RunPass(hlo_string, /*expect_change=*/true)); - EXPECT_EQ(AllReduceCount(module), 1); - HloInstruction *root = module->entry_computation()->root_instruction(); - EXPECT_THAT(root, m::AllReduce(m::Parameter(0))); - EXPECT_THAT(root->ToString(), - HasSubstr("replica_groups={{0,1,4,5},{2,3,6,7}}")); + RunAndCheckHloRewrite(hlo_string, AllReduceFolder())); + ExpectOneAllReduce(module.get(), "replica_groups={{0,1,4,5},{2,3,6,7}}"); } // Verify that a chain of foldable all-reduce's folds in a single pass // invocation. TEST_F(AllReduceFolderTest, FoldableChain) { absl::string_view hlo_string = R"( -HloModule m + HloModule m -sum { - a = f32[] parameter(0) - b = f32[] parameter(1) - ROOT add.2 = f32[] add(a, b) -} + sum { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT add.2 = f32[] add(a, b) + } -ENTRY main { - p0 = f32[8] parameter(0) - ar0 = f32[8] all-reduce(p0), replica_groups={{0,1},{2,3},{4,5},{6,7}}, to_apply=sum - ar1 = f32[8] all-reduce(ar0), replica_groups={{0,2},{1,3},{4,6},{5,7}}, to_apply=sum - ROOT ar2 = f32[8] all-reduce(ar1), replica_groups={{0,4},{1,5},{2,6},{3,7}}, to_apply=sum -} -)"; + ENTRY main { + p0 = f32[8] parameter(0) + ar0 = f32[8] all-reduce(p0), replica_groups={{0,1},{2,3},{4,5},{6,7}}, to_apply=sum + ar1 = f32[8] all-reduce(ar0), replica_groups={{0,2},{1,3},{4,6},{5,7}}, to_apply=sum + ROOT ar2 = f32[8] all-reduce(ar1), replica_groups={{0,4},{1,5},{2,6},{3,7}}, to_apply=sum + } + )"; TF_ASSERT_OK_AND_ASSIGN(auto module, - RunPass(hlo_string, /*expect_change=*/true)); - std::cerr << module->ToString(); - EXPECT_EQ(AllReduceCount(module), 1); - HloInstruction *root = module->entry_computation()->root_instruction(); - EXPECT_THAT(root, m::AllReduce(m::Parameter(0))); - EXPECT_THAT(root->ToString(), - HasSubstr("replica_groups={{0,1,2,3,4,5,6,7}}")); + RunAndCheckHloRewrite(hlo_string, AllReduceFolder())); + ExpectOneAllReduce(module.get(), "replica_groups={{0,1,2,3,4,5,6,7}}"); } } // namespace diff --git a/xla/tests/BUILD b/xla/tests/BUILD index aa81fe0e5b743..1561843df8a7c 100644 --- a/xla/tests/BUILD +++ b/xla/tests/BUILD @@ -179,15 +179,19 @@ cc_library( ":test_utils", ":verified_hlo_module", "//xla:debug_options_flags", + "//xla:error_spec", + "//xla:literal", "//xla:shape_layout", "//xla:shape_util", "//xla:types", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/hlo/ir:hlo_module_group", + "//xla/hlo/pass:hlo_pass", "//xla/hlo/utils:hlo_query", "//xla/service:backend", "//xla/service:computation_layout", + "//xla/service:hlo_module_config", "//xla/service:hlo_module_util", "//xla/service:hlo_parser", "//xla/service:hlo_runner", @@ -200,12 +204,17 @@ cc_library( "//xla/stream_executor:stream_executor_memory_allocator", "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/types:span", "@tsl//tsl/platform:errors", "@tsl//tsl/platform:logging", + "@tsl//tsl/platform:statusor", "@tsl//tsl/platform:test", ], ) diff --git a/xla/tests/hlo_test_base.cc b/xla/tests/hlo_test_base.cc index dbbcd5f866e92..28985fd2ba33b 100644 --- a/xla/tests/hlo_test_base.cc +++ b/xla/tests/hlo_test_base.cc @@ -19,8 +19,8 @@ limitations under the License. #include #include #include -#include #include +#include #include #include @@ -28,19 +28,19 @@ limitations under the License. #include "absl/log/check.h" #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/str_replace.h" +#include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/debug_options_flags.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/pass/hlo_pass_interface.h" #include "xla/hlo/utils/hlo_query.h" -#include "xla/layout_util.h" #include "xla/service/hlo_module_util.h" -#include "xla/service/hlo_parser.h" #include "xla/service/hlo_runner_interface.h" #include "xla/service/hlo_runner_pjrt.h" #include "xla/service/platform_util.h" #include "xla/shape.h" -#include "xla/shape_util.h" #include "xla/stream_executor/device_memory_allocator.h" #include "xla/stream_executor/stream_executor_memory_allocator.h" #include "xla/tests/filecheck.h" @@ -49,9 +49,9 @@ limitations under the License. #include "xla/tests/test_utils.h" #include "xla/tests/verified_hlo_module.h" #include "xla/tsl/lib/core/status_test_util.h" -#include "xla/types.h" #include "tsl/platform/errors.h" #include "tsl/platform/logging.h" +#include "tsl/platform/statusor.h" #include "tsl/platform/test.h" namespace xla { @@ -323,6 +323,21 @@ void HloTestBase::RunAndFilecheckHloModuleGroupRewrite( } } +absl::StatusOr> HloTestBase::RunAndCheckHloRewrite( + absl::string_view hlo_template, HloPassInterface&& hlo_pass, + bool expect_change, FixedMapping params) { + std::string hlo_string = absl::StrReplaceAll(hlo_template, params); + SCOPED_TRACE("Input HLO: " + hlo_string); + VLOG(7) << "Input HLO: " << hlo_string; + TF_ASSIGN_OR_RETURN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSIGN_OR_RETURN(bool changed, RunHloPass(hlo_pass, module.get())); + VLOG(7) << "Output HLO: " + << module->ToString(HloPrintOptions::ShortParsable()); + EXPECT_EQ(changed, expect_change); + return module; +} + absl::StatusOr HloTestBase::Execute( std::unique_ptr module, absl::Span arguments, bool run_hlo_passes) { @@ -665,7 +680,8 @@ ::testing::AssertionResult HloTestBase::RunAndCompareTwoModulesReplicated( auto num_args = module_0->entry_computation()->num_parameters(); if (num_args != options.arguments.size()) { return ::testing::AssertionFailure() - << "Mismatch in number of arguments passed while running replicated " + << "Mismatch in number of arguments passed while running " + "replicated " "hlo module. Expected: " << num_args << ", actual: " << options.arguments.size(); } diff --git a/xla/tests/hlo_test_base.h b/xla/tests/hlo_test_base.h index e075d7fd7123a..c312fc35090d7 100644 --- a/xla/tests/hlo_test_base.h +++ b/xla/tests/hlo_test_base.h @@ -16,21 +16,31 @@ limitations under the License. #ifndef XLA_TESTS_HLO_TEST_BASE_H_ #define XLA_TESTS_HLO_TEST_BASE_H_ +#include #include +#include #include #include #include #include #include +#include "absl/base/attributes.h" +#include "absl/log/log.h" +#include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "xla/error_spec.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_module_group.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/pass/hlo_pass_interface.h" +#include "xla/literal.h" #include "xla/service/backend.h" #include "xla/service/computation_layout.h" +#include "xla/service/hlo_module_config.h" #include "xla/service/hlo_runner.h" #include "xla/service/hlo_verifier.h" #include "xla/service/platform_util.h" @@ -187,6 +197,17 @@ class HloTestBase : public ::testing::Test { HloPassInterface&& hlo_pass, std::optional> expected); + using FixedMapping = + std::initializer_list>; + + // Creates an HLO module from a template and an optional replacement map and + // runs the given hlo_pass on the module. Validates whether the pass has + // changed the module or not based on expect_change flag. Returns unique_ptr + // to the HLO module for further inspection. + absl::StatusOr> RunAndCheckHloRewrite( + absl::string_view hlo_template, HloPassInterface&& hlo_pass, + bool expect_change = true, FixedMapping params = {}); + // Populates debug options from command-line flags and adjusts the options for // testing. It is recommended to use this when you need to pass in // DebugOptions, e.g. when creating a module from a string or a file.