Skip to content

Commit

Permalink
[xla] Add a test for HLO deduplication + execution threads
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 647816688
  • Loading branch information
ezhulenev authored and copybara-github committed Jun 28, 2024
1 parent 9a6abc3 commit ff21834
Show file tree
Hide file tree
Showing 4 changed files with 81 additions and 14 deletions.
11 changes: 11 additions & 0 deletions xla/service/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -1159,7 +1159,16 @@ cc_library(
hdrs = ["hlo_computation_deduplicator.h"],
deps = [
":hlo_pass",
"//xla:shape_util",
"//xla:status_macros",
"//xla:util",
"//xla/hlo/ir:hlo",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:string_view",
"@tsl//tsl/platform:logging",
],
)

Expand All @@ -1171,6 +1180,7 @@ xla_cc_test(
":hlo_computation_deduplicator",
":hlo_pass",
"//xla:literal",
"//xla:literal_util",
"//xla:shape_util",
"//xla:test",
"//xla:types",
Expand All @@ -1181,6 +1191,7 @@ xla_cc_test(
"//xla/tests:xla_internal_test_main",
"@com_google_googletest//:gtest_main",
"@tsl//tsl/lib/core:status_test_util",
"@tsl//tsl/platform:statusor",
],
)

Expand Down
8 changes: 7 additions & 1 deletion xla/service/hlo_computation_deduplicator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,17 @@ limitations under the License.

#include "xla/service/hlo_computation_deduplicator.h"

#include <algorithm>
#include <string>
#include <utility>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "xla/hlo/ir/hlo_computation.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/shape_util.h"
#include "tsl/platform/logging.h"

namespace xla {

Expand All @@ -36,6 +41,7 @@ bool HloComputationDeduplicator::ContainsLargeConstants(HloComputation* comp) {
}
return false;
}

absl::StatusOr<bool> HloComputationDeduplicator::Run(
HloModule* module,
const absl::flat_hash_set<absl::string_view>& execution_threads) {
Expand Down
12 changes: 8 additions & 4 deletions xla/service/hlo_computation_deduplicator.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,17 @@ limitations under the License.
#ifndef XLA_SERVICE_HLO_COMPUTATION_DEDUPLICATOR_H_
#define XLA_SERVICE_HLO_COMPUTATION_DEDUPLICATOR_H_

#include "absl/container/flat_hash_set.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "xla/hlo/ir/hlo_computation.h"
#include "xla/service/hlo_pass_interface.h"

namespace xla {

// Deduplicate computations inside a `HloModule`: If two computations are
// identical then keep the first one (in postorder terms) and remove the rest.
class HloComputationDeduplicator : public HloModulePass {
private:
bool ContainsLargeConstants(HloComputation* comp);
bool mark_fusion_duplications_;

public:
// Setting mark_fusion_duplications to true will only process fusions in the
// HLO. The comparator in this pass will mark duplicate fusions which is
Expand All @@ -40,6 +40,10 @@ class HloComputationDeduplicator : public HloModulePass {
absl::StatusOr<bool> Run(
HloModule* module,
const absl::flat_hash_set<absl::string_view>& execution_threads) override;

private:
bool ContainsLargeConstants(HloComputation* comp);
bool mark_fusion_duplications_;
};

} // namespace xla
Expand Down
64 changes: 55 additions & 9 deletions xla/service/hlo_computation_deduplicator_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ limitations under the License.

#include "xla/service/hlo_computation_deduplicator.h"

#include <algorithm>
#include <cstdint>
#include <memory>
#include <string>
Expand All @@ -27,16 +26,12 @@ limitations under the License.
#include "xla/hlo/ir/hlo_computation.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_opcode.h"
#include "xla/hlo/utils/hlo_matchers.h"
#include "xla/layout_util.h"
#include "xla/literal.h"
#include "xla/service/hlo_pass_fix.h"
#include "xla/literal_util.h"
#include "xla/shape.h"
#include "xla/shape_util.h"
#include "xla/test.h"
#include "xla/tests/hlo_test_base.h"
#include "xla/types.h"
#include "xla/xla_data.pb.h"
#include "tsl/lib/core/status_test_util.h"

namespace xla {
namespace {
Expand Down Expand Up @@ -100,6 +95,7 @@ TEST_F(HloComputationDeduplicatorTest, RemoveRegionBandC) {
}
EXPECT_EQ(computation_names.size(), 2);
}

TEST_F(HloComputationDeduplicatorTest, RemoveRegionBExactCopy) {
const std::string_view text = R"(
HloModule DeDupTest, entry_computation_layout={(s32[10]{0},s32[15]{0})->s32[]}
Expand Down Expand Up @@ -181,7 +177,7 @@ TEST_F(HloComputationDeduplicatorTest, RemoveRegionsWithSameSubcomp) {
rd1 = s32[] call(Arg_0, Arg_1), to_apply=main.15
rd2 = s32[] call(Arg_0, Arg_1), to_apply=main.16
ROOT ret = add(rd1, rd2)
}
}
)";

auto computation_names = RunDeduplicatePass(text, /*expect_true=*/true);
Expand All @@ -195,6 +191,7 @@ TEST_F(HloComputationDeduplicatorTest, RemoveRegionsWithSameSubcomp) {
}
EXPECT_EQ(computation_names.size(), 3);
}

TEST_F(HloComputationDeduplicatorTest, DontRemoveRegionsWithDifferentSubcomp) {
const std::string_view text = R"(
HloModule DeDupTest, entry_computation_layout={(s32[10]{0},s32[15]{0})->s32[]}
Expand Down Expand Up @@ -334,7 +331,7 @@ TEST_F(HloComputationDeduplicatorTest, DontRemoveRegionBCommutative) {
)";

auto computation_names = RunDeduplicatePass(text, /*expect_true=*/false);
// Will also take into account commutativety.
// Will also take into account commutativity.
int region_b_count = 0;
for (auto name : computation_names) {
region_b_count += (name == "region_B");
Expand All @@ -343,6 +340,54 @@ TEST_F(HloComputationDeduplicatorTest, DontRemoveRegionBCommutative) {
EXPECT_EQ(computation_names.size(), 3);
}

// Checks that an identical computation is kept when it is used from a
// computation that runs on a different execution thread.
//
// `region_A` and `region_B` have the same body. `region_A` is applied by a
// reduce in the entry computation, while `region_B` is applied by a reduce in
// `called_computation`, which is declared with
// execution_thread="parallel_thread" and is invoked through an async
// call-start/call-done pair.
TEST_F(HloComputationDeduplicatorTest,
DontRemoveRegionBDifferentExecutionThread) {
const std::string_view text = R"(
HloModule DeDupTest, entry_computation_layout={(s32[10]{0},s32[15]{0})->s32[]}
region_A {
Arg_0 = s32[] parameter(0)
Arg_1 = s32[] parameter(1)
ROOT add = s32[] add(Arg_0, Arg_1)
}
region_B {
Arg_0 = s32[] parameter(0)
Arg_1 = s32[] parameter(1)
ROOT add = s32[] add(Arg_0, Arg_1)
}
called_computation {
Arg_0 = s32[15]{0} parameter(0)
Cst = s32[] constant(0)
ROOT rd2 = s32[] reduce(Arg_0, Cst), dimensions={0}, to_apply=region_B
}, execution_thread="parallel_thread"
ENTRY main.15 {
Arg_0 = s32[10]{0} parameter(0)
constant.3 = s32[] constant(0)
rd1 = s32[] reduce(Arg_0, constant.3), dimensions={0}, to_apply=region_A
Arg_1 = s32[15]{0} parameter(1)
call-start = ((s32[15]{0}), s32[], s32[]) call-start(Arg_1),
async_execution_thread="parallel_thread",
to_apply=%called_computation
call-done = s32[] call-done(call-start)
ROOT multiply.14 = s32[] multiply(rd1, call-done)
}
)";

// NOTE(review): `expect_true=false` presumably asserts that the pass reports
// no change; RunDeduplicatePass is defined outside this view - confirm.
auto computation_names = RunDeduplicatePass(text, /*expect_true=*/false);
// region_B must not be folded into region_A: it should still appear exactly
// once, and the number of computation names returned should be five.
int region_b_count = 0;
for (auto name : computation_names) {
region_b_count += (name == "region_B");
}
EXPECT_EQ(region_b_count, 1);
EXPECT_EQ(computation_names.size(), 5);
}

TEST_F(HloComputationDeduplicatorTest, DontRemoveRegionLargeConstant) {
const std::string_view text = R"(
HloModule DeDupTest, entry_computation_layout={(s32[10]{0},s32[15]{0})->s32[]}
Expand Down Expand Up @@ -618,5 +663,6 @@ TEST_F(HloComputationDeduplicatorTest, DontDeduplicateReduceAllReduce) {
auto computation_names = RunDeduplicatePass(text, /*expect_true=*/false);
EXPECT_EQ(computation_names.size(), 3);
}

} // namespace
} // namespace xla

0 comments on commit ff21834

Please sign in to comment.