Skip to content

Commit

Permalink
Add hlo_sharding_util::MoveAndMergeShardingTiles.
Browse files Browse the repository at this point in the history
Given a tiled sharding, move the tiles from source_dim and merge it into
target_dim. For example, given a sharding with tile assignment `[a, b, c, d,
e]`, source_dim = 1, target_dim = 3, the function will return a sharding with
tile assignment `[a, 1, c, db, e]`.

PiperOrigin-RevId: 694580468
  • Loading branch information
ZixuanJiang authored and Google-ML-Automation committed Nov 8, 2024
1 parent b8ec2c4 commit 7febd1c
Show file tree
Hide file tree
Showing 3 changed files with 86 additions and 0 deletions.
51 changes: 51 additions & 0 deletions xla/hlo/utils/hlo_sharding_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -673,6 +673,57 @@ std::optional<int64_t> GetDominantDevice(
return dominant_device;
}

HloSharding MoveAndMergeShardingTiles(const HloSharding& sharding,
int64_t source_dim, int64_t target_dim) {
CHECK(sharding.IsTiled());

CHECK_NE(source_dim, target_dim);
CHECK_GE(source_dim, 0);
CHECK_GE(target_dim, 0);
CHECK_LT(source_dim, sharding.TiledDataRank());
CHECK_LT(target_dim, sharding.TiledDataRank());

// There are 3 steps to move and merge the sharding tiles. Given the sharding
// with tile assignment [a, b, c, d, e], source_dim = 1, target_dim = 3, the
// steps are:
// 1. Reshape the tile assignment to [a, b, c, d, 1, e] by inserting a 1 after
// the target_dim.
// 2. Transpose the tile assignment to [a, 1, c, d, b, e] by swapping the
// source_dim and inserted dim of size 1.
// 3. Reshape the tile assignment to [a, 1, c, db, e] by merging the
// target_dim and the swapped source_dim.

// Step 1. Adding a dummy dim of size 1 after the target_dim.
std::vector<int64_t> ta_dims_1(
sharding.tile_assignment().dimensions().begin(),
sharding.tile_assignment().dimensions().end());
ta_dims_1.insert(ta_dims_1.begin() + target_dim + 1, 1);
TileAssignment new_tile_assignment =
sharding.tile_assignment().Reshape(ta_dims_1);

// Step 2. Transpose the tile assignment to swap the source_dim and the
// inserted dim of size 1.
std::vector<int> permutation(new_tile_assignment.num_dimensions());
absl::c_iota(permutation, 0);
std::swap(permutation[target_dim + 1],
permutation[source_dim + (source_dim < target_dim ? 0 : 1)]);
new_tile_assignment = new_tile_assignment.Transpose(permutation);

// Step 3. Reshape the tile assignment to merge the target_dim and the swapped
// source_dim.
std::vector<int64_t> ta_dims_2(new_tile_assignment.dimensions().begin(),
new_tile_assignment.dimensions().end());
ta_dims_2[target_dim] *= ta_dims_2[target_dim + 1];
ta_dims_2.erase(ta_dims_2.begin() + target_dim + 1);
new_tile_assignment = new_tile_assignment.Reshape(ta_dims_2);

if (sharding.ReplicateOnLastTileDim()) {
return HloSharding::PartialTile(new_tile_assignment, sharding.metadata());
}
return HloSharding::Subgroup(new_tile_assignment, sharding.subgroup_types(),
sharding.metadata());
}

HloSharding TransposeSharding(const HloSharding& sharding,
absl::Span<const int64_t> dimensions) {
if (sharding.IsTileMaximal() || sharding.IsManual()) {
Expand Down
7 changes: 7 additions & 0 deletions xla/hlo/utils/hlo_sharding_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,13 @@ std::optional<int64_t> GetMostOccurringDevice(
std::optional<int64_t> GetDominantDevice(
absl::Span<HloComputation* const> computations, double dominant_factor);

// Given a tiled sharding, move the tiles from source_dim and merge it into
// target_dim. For example, given a sharding with tile assignment [a, b, c, d,
// e], source_dim = 1, target_dim = 3, the function will return a sharding with
// tile assignment [a, 1, c, db, e].
HloSharding MoveAndMergeShardingTiles(const HloSharding& sharding,
int64_t source_dim, int64_t target_dim);

// Returns the HloSharding with the tile dimensions and tile assignment
// transposed based on the specified dimension numbers. In case of a tile
// maximal sharding returns the original sharding.
Expand Down
28 changes: 28 additions & 0 deletions xla/hlo/utils/hlo_sharding_util_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,34 @@ TEST(HloShardingUtilTest, MergeShardingIfCompatible8) {
HloSharding::Tile(TileAssignment({2, 4}, {2, 2, 2}, {0, 2, 1})));
}

TEST(HloShardingUtilTest, MoveAndMergeShardingTilesPartialTile) {
HloSharding sharding =
HloSharding::PartialTile(TileAssignment({2, 3, 5, 7, 11}));
EXPECT_EQ(MoveAndMergeShardingTiles(sharding, 1, 3),
HloSharding::PartialTile(TileAssignment(
{2, 1, 5, 7 * 3, 11}, {2, 3, 5, 7, 11}, {0, 2, 3, 1, 4})));

EXPECT_EQ(MoveAndMergeShardingTiles(sharding, 3, 1),
HloSharding::PartialTile(TileAssignment(
{2, 3 * 7, 5, 1, 11}, {2, 3, 5, 7, 11}, {0, 1, 3, 2, 4})));
}

TEST(HloShardingUtilTest, MoveAndMergeShardingTilesSubGroup) {
HloSharding sharding =
HloSharding::Subgroup(TileAssignment({2, 3, 5, 7, 11}),
{OpSharding::MANUAL, OpSharding::REPLICATED});
EXPECT_EQ(
MoveAndMergeShardingTiles(sharding, 0, 2),
HloSharding::Subgroup(TileAssignment({1, 3, 5 * 2, 7, 11},
{2, 3, 5, 7, 11}, {1, 2, 0, 3, 4}),
{OpSharding::MANUAL, OpSharding::REPLICATED}));
EXPECT_EQ(
MoveAndMergeShardingTiles(sharding, 2, 0),
HloSharding::Subgroup(TileAssignment({2 * 5, 3, 1, 7, 11},
{2, 3, 5, 7, 11}, {0, 2, 1, 3, 4}),
{OpSharding::MANUAL, OpSharding::REPLICATED}));
}

TEST(HloShardingUtilTest, TransposeShardingReplicated) {
EXPECT_EQ(TransposeSharding(HloSharding::Replicate(), {0, 1, 2}),
HloSharding::Replicate());
Expand Down

0 comments on commit 7febd1c

Please sign in to comment.