Skip to content

Commit

Permalink
Add a helper method to HloTestBase to run a pass on a parameterized H…
Browse files Browse the repository at this point in the history
…LO string.

This is a common pattern in HLO transformation tests, and it's useful to have a helper method to reduce boilerplate.
This CL also updates all_reduce_folder_test.cc to use the new helper method.

PiperOrigin-RevId: 679309211
  • Loading branch information
toli-y authored and Google-ML-Automation committed Sep 26, 2024
1 parent 95beb0e commit 011c02f
Show file tree
Hide file tree
Showing 5 changed files with 174 additions and 170 deletions.
2 changes: 1 addition & 1 deletion xla/service/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -276,8 +276,8 @@ xla_cc_test(
"//xla/hlo/utils:hlo_matchers",
"//xla/tests:hlo_test_base",
"//xla/tests:xla_internal_test_main",
"//xla/tsl/lib/core:status_test_util",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:string_view",
"@tsl//tsl/platform:statusor",
],
Expand Down
284 changes: 121 additions & 163 deletions xla/service/all_reduce_folder_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,233 +16,191 @@ limitations under the License.
#include "xla/service/all_reduce_folder.h"

#include <cstddef>
#include <iostream>
#include <initializer_list>
#include <memory>
#include <utility>

#include "absl/algorithm/container.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_module.h"
#include "xla/hlo/ir/hlo_opcode.h"
#include "xla/hlo/utils/hlo_matchers.h"
#include "xla/test.h"
#include "xla/tests/hlo_test_base.h"
#include "xla/tsl/lib/core/status_test_util.h"
#include "tsl/platform/statusor.h"

namespace xla {
namespace {

namespace m = xla::testing::opcode_matchers;
namespace matcher = xla::testing::opcode_matchers;
using ::testing::HasSubstr;

class AllReduceFolderTest : public HloTestBase {
public:
absl::StatusOr<std::unique_ptr<HloModule>> RunPass(
absl::string_view hlo_module, bool expect_change) {
TF_ASSIGN_OR_RETURN(auto module, ParseAndReturnVerifiedModule(hlo_module));
auto changed = AllReduceFolder().Run(module.get());
if (!changed.ok()) {
return changed.status();
}
EXPECT_EQ(changed.value(), expect_change);
return absl::StatusOr<std::unique_ptr<HloModule>>(std::move(module));
}
class AllReduceFolderTest : public HloTestBase {};

size_t AllReduceCount(std::unique_ptr<HloModule> &module) {
return absl::c_count_if(module->entry_computation()->instructions(),
HloPredicateIsOp<HloOpcode::kAllReduce>);
}
};
const char *k2AllReduce = R"(
HloModule m
TEST_F(AllReduceFolderTest, Simple) {
absl::string_view hlo_string = R"(
HloModule m
sum {
a = f32[] parameter(0)
b = f32[] parameter(1)
ROOT add.2 = f32[] add(a, b)
}
sum {
a = f32[] parameter(0)
b = f32[] parameter(1)
ROOT add.2 = f32[] add(a, b)
}
ENTRY main {
p0 = f32[8] parameter(0)
ar0 = f32[8] all-reduce(p0), replica_groups=$group_0, to_apply=sum
ROOT ar1 = f32[8] all-reduce(ar0), replica_groups=$group_1, to_apply=sum
}
)";

ENTRY main {
p0 = f32[8] parameter(0)
ar0 = f32[8] all-reduce(p0), replica_groups={{0,1},{2,3}}, to_apply=sum
ROOT ar1 = f32[8] all-reduce(ar0), replica_groups={{0,2},{1,3}}, to_apply=sum
size_t AllReduceCount(HloModule *module) {
return absl::c_count_if(module->entry_computation()->instructions(),
HloPredicateIsOp<HloOpcode::kAllReduce>);
}
)";
TF_ASSERT_OK_AND_ASSIGN(auto module,
RunPass(hlo_string, /*expect_change=*/true));

void ExpectOneAllReduce(HloModule *module,
absl::string_view target_replica_groups) {
EXPECT_EQ(AllReduceCount(module), 1);
HloInstruction *root = module->entry_computation()->root_instruction();
EXPECT_THAT(root, m::AllReduce(m::Parameter(0)));
EXPECT_THAT(root->ToString(), HasSubstr("replica_groups={{0,1,2,3}}"));
EXPECT_THAT(root, matcher::AllReduce(matcher::Parameter(0)));
EXPECT_THAT(root->ToString(), HasSubstr(target_replica_groups));
}

// Same as Simple, but groups for the 2 all-reduce's are swapped.
TEST_F(AllReduceFolderTest, SimpleSwap) {
absl::string_view hlo_string = R"(
HloModule m
sum {
a = f32[] parameter(0)
b = f32[] parameter(1)
ROOT add.2 = f32[] add(a, b)
TEST_F(AllReduceFolderTest, Simple) {
TF_ASSERT_OK_AND_ASSIGN(
auto module, RunAndCheckHloRewrite(k2AllReduce, AllReduceFolder(), true,
{{"$group_0", "{{0,1},{2,3}}"},
{"$group_1", "{{0,2},{1,3}}"}}));
ExpectOneAllReduce(module.get(), "replica_groups={{0,1,2,3}}");
}

ENTRY main {
p0 = f32[8] parameter(0)
ar0 = f32[8] all-reduce(p0), replica_groups={{0,2},{1,3}}, to_apply=sum
ROOT ar1 = f32[8] all-reduce(ar0), replica_groups={{0,1},{2,3}}, to_apply=sum
}
)";
TF_ASSERT_OK_AND_ASSIGN(auto module,
RunPass(hlo_string, /*expect_change=*/true));
EXPECT_EQ(AllReduceCount(module), 1);
HloInstruction *root = module->entry_computation()->root_instruction();
EXPECT_THAT(root, m::AllReduce(m::Parameter(0)));
EXPECT_THAT(root->ToString(), HasSubstr("replica_groups={{0,1,2,3}}"));
// Same as Simple, but groups for the 2 all-reduce's are swapped.
TEST_F(AllReduceFolderTest, SimpleSwap) {
TF_ASSERT_OK_AND_ASSIGN(
auto module, RunAndCheckHloRewrite(k2AllReduce, AllReduceFolder(), true,
{{"$group_1", "{{0,1},{2,3}}"},
{"$group_0", "{{0,2},{1,3}}"}}));
ExpectOneAllReduce(module.get(), "replica_groups={{0,1,2,3}}");
}

TEST_F(AllReduceFolderTest, EmptyReplicaGroups) {
absl::string_view hlo_string = R"(
HloModule m
sum {
a = f32[] parameter(0)
b = f32[] parameter(1)
ROOT add.2 = f32[] add(a, b)
TEST_F(AllReduceFolderTest, BothEmptyReplicaGroups_NotTransformed) {
TF_ASSERT_OK(RunAndCheckHloRewrite(k2AllReduce, AllReduceFolder(), false,
{{"$group_0", "{}"}, {"$group_1", "{}"}}));
}

ENTRY main {
p0 = f32[8] parameter(0)
ar0 = f32[8] all-reduce(p0), replica_groups={}, to_apply=sum
ROOT ar1 = f32[8] all-reduce(ar0), replica_groups={{0,2},{1,3}}, to_apply=sum
}
)";
TF_ASSERT_OK_AND_ASSIGN(auto module,
RunPass(hlo_string, /*expect_change=*/false));
TEST_F(AllReduceFolderTest, EmptyReplicaGroups_NotTransformed) {
TF_ASSERT_OK(RunAndCheckHloRewrite(
k2AllReduce, AllReduceFolder(), false,
{{"$group_0", "{}"}, {"$group_1", "{{0,2},{1,3}}"}}));
}

TEST_F(AllReduceFolderTest, MismatchOtherProperties0) {
TEST_F(AllReduceFolderTest, MismatchOtherProperties0_NotTransformed) {
absl::string_view hlo_string = R"(
HloModule m
HloModule m
sum {
a = f32[] parameter(0)
b = f32[] parameter(1)
ROOT add.2 = f32[] add(a, b)
}
sum {
a = f32[] parameter(0)
b = f32[] parameter(1)
ROOT add.2 = f32[] add(a, b)
}
ENTRY main {
p0 = f32[8] parameter(0)
ar0 = f32[8] all-reduce(p0), replica_groups={{0,1},{2,3}}, channel_id=1, to_apply=sum
ROOT ar1 = f32[8] all-reduce(ar0), replica_groups={{0,2},{1,3}}, to_apply=sum
}
)";
TF_ASSERT_OK_AND_ASSIGN(auto module,
RunPass(hlo_string, /*expect_change=*/false));
ENTRY main {
p0 = f32[8] parameter(0)
ar0 = f32[8] all-reduce(p0), replica_groups={{0,1},{2,3}}, channel_id=1, to_apply=sum
ROOT ar1 = f32[8] all-reduce(ar0), replica_groups={{0,2},{1,3}}, to_apply=sum
}
)";
TF_ASSERT_OK(RunAndCheckHloRewrite(hlo_string, AllReduceFolder(), false));
}

TEST_F(AllReduceFolderTest, MismatchOtherProperties1) {
TEST_F(AllReduceFolderTest, MismatchOtherProperties1_NotTransformed) {
absl::string_view hlo_string = R"(
HloModule m
HloModule m
sum {
a = f32[] parameter(0)
b = f32[] parameter(1)
ROOT add.2 = f32[] add(a, b)
}
sum {
a = f32[] parameter(0)
b = f32[] parameter(1)
ROOT add.2 = f32[] add(a, b)
}
mul {
a = f32[] parameter(0)
b = f32[] parameter(1)
ROOT mul = f32[] multiply(a, b)
}
mul {
a = f32[] parameter(0)
b = f32[] parameter(1)
ROOT mul = f32[] multiply(a, b)
}
ENTRY main {
p0 = f32[8] parameter(0)
ar0 = f32[8] all-reduce(p0), replica_groups={{0,1},{2,3}}, to_apply=sum
ROOT ar1 = f32[8] all-reduce(ar0), replica_groups={{0,2},{1,3}}, to_apply=mul
}
)";
TF_ASSERT_OK_AND_ASSIGN(auto module,
RunPass(hlo_string, /*expect_change=*/false));
ENTRY main {
p0 = f32[8] parameter(0)
ar0 = f32[8] all-reduce(p0), replica_groups={{0,1},{2,3}}, to_apply=sum
ROOT ar1 = f32[8] all-reduce(ar0), replica_groups={{0,2},{1,3}}, to_apply=mul
}
)";
TF_ASSERT_OK(RunAndCheckHloRewrite(hlo_string, AllReduceFolder(), false));
}

TEST_F(AllReduceFolderTest, NotFoldable) {
TEST_F(AllReduceFolderTest, NotFoldable_NotTransformed) {
absl::string_view hlo_string = R"(
HloModule m
HloModule m
sum {
a = f32[] parameter(0)
b = f32[] parameter(1)
ROOT add.2 = f32[] add(a, b)
}
sum {
a = f32[] parameter(0)
b = f32[] parameter(1)
ROOT add.2 = f32[] add(a, b)
}
ENTRY main {
p0 = f32[8] parameter(0)
ar0 = f32[8] all-reduce(p0), replica_groups={{0,1},{2,3}}, to_apply=sum
ROOT ar1 = f32[8] all-reduce(ar0), replica_groups={{0,1},{2,3}}, to_apply=sum
}
)";
TF_ASSERT_OK_AND_ASSIGN(auto module,
RunPass(hlo_string, /*expect_change=*/false));
ENTRY main {
p0 = f32[8] parameter(0)
ar0 = f32[8] all-reduce(p0), replica_groups={{0,1},{2,3}}, to_apply=sum
ROOT ar1 = f32[8] all-reduce(ar0), replica_groups={{0,1},{2,3}}, to_apply=sum
}
)";
TF_ASSERT_OK(RunAndCheckHloRewrite(hlo_string, AllReduceFolder(), false));
}

TEST_F(AllReduceFolderTest, Foldable0) {
absl::string_view hlo_string = R"(
HloModule m
HloModule m
sum {
a = f32[] parameter(0)
b = f32[] parameter(1)
ROOT add.2 = f32[] add(a, b)
}
sum {
a = f32[] parameter(0)
b = f32[] parameter(1)
ROOT add.2 = f32[] add(a, b)
}
ENTRY main {
p0 = f32[8] parameter(0)
ar0 = f32[8] all-reduce(p0), replica_groups={{0,4},{1,5},{2,3},{6,7}}, to_apply=sum
ROOT ar1 = f32[8] all-reduce(ar0), replica_groups={{0,5},{4,1},{2,7},{3,6}}, to_apply=sum
}
)";
ENTRY main {
p0 = f32[8] parameter(0)
ar0 = f32[8] all-reduce(p0), replica_groups={{0,4},{1,5},{2,3},{6,7}}, to_apply=sum
ROOT ar1 = f32[8] all-reduce(ar0), replica_groups={{0,5},{4,1},{2,7},{3,6}}, to_apply=sum
}
)";
TF_ASSERT_OK_AND_ASSIGN(auto module,
RunPass(hlo_string, /*expect_change=*/true));
EXPECT_EQ(AllReduceCount(module), 1);
HloInstruction *root = module->entry_computation()->root_instruction();
EXPECT_THAT(root, m::AllReduce(m::Parameter(0)));
EXPECT_THAT(root->ToString(),
HasSubstr("replica_groups={{0,1,4,5},{2,3,6,7}}"));
RunAndCheckHloRewrite(hlo_string, AllReduceFolder()));
ExpectOneAllReduce(module.get(), "replica_groups={{0,1,4,5},{2,3,6,7}}");
}

// Verify that a chain of foldable all-reduce's folds in a single pass
// invocation.
TEST_F(AllReduceFolderTest, FoldableChain) {
absl::string_view hlo_string = R"(
HloModule m
HloModule m
sum {
a = f32[] parameter(0)
b = f32[] parameter(1)
ROOT add.2 = f32[] add(a, b)
}
sum {
a = f32[] parameter(0)
b = f32[] parameter(1)
ROOT add.2 = f32[] add(a, b)
}
ENTRY main {
p0 = f32[8] parameter(0)
ar0 = f32[8] all-reduce(p0), replica_groups={{0,1},{2,3},{4,5},{6,7}}, to_apply=sum
ar1 = f32[8] all-reduce(ar0), replica_groups={{0,2},{1,3},{4,6},{5,7}}, to_apply=sum
ROOT ar2 = f32[8] all-reduce(ar1), replica_groups={{0,4},{1,5},{2,6},{3,7}}, to_apply=sum
}
)";
ENTRY main {
p0 = f32[8] parameter(0)
ar0 = f32[8] all-reduce(p0), replica_groups={{0,1},{2,3},{4,5},{6,7}}, to_apply=sum
ar1 = f32[8] all-reduce(ar0), replica_groups={{0,2},{1,3},{4,6},{5,7}}, to_apply=sum
ROOT ar2 = f32[8] all-reduce(ar1), replica_groups={{0,4},{1,5},{2,6},{3,7}}, to_apply=sum
}
)";
TF_ASSERT_OK_AND_ASSIGN(auto module,
RunPass(hlo_string, /*expect_change=*/true));
std::cerr << module->ToString();
EXPECT_EQ(AllReduceCount(module), 1);
HloInstruction *root = module->entry_computation()->root_instruction();
EXPECT_THAT(root, m::AllReduce(m::Parameter(0)));
EXPECT_THAT(root->ToString(),
HasSubstr("replica_groups={{0,1,2,3,4,5,6,7}}"));
RunAndCheckHloRewrite(hlo_string, AllReduceFolder()));
ExpectOneAllReduce(module.get(), "replica_groups={{0,1,2,3,4,5,6,7}}");
}

} // namespace
Expand Down
Loading

0 comments on commit 011c02f

Please sign in to comment.