diff --git a/xla/hlo/utils/hlo_sharding_util.cc b/xla/hlo/utils/hlo_sharding_util.cc index 7fc9d7761e26e..def884e17ef0c 100644 --- a/xla/hlo/utils/hlo_sharding_util.cc +++ b/xla/hlo/utils/hlo_sharding_util.cc @@ -673,6 +673,57 @@ std::optional GetDominantDevice( return dominant_device; } +HloSharding MoveAndMergeShardingTiles(const HloSharding& sharding, + int64_t source_dim, int64_t target_dim) { + CHECK(sharding.IsTiled()); + + CHECK_NE(source_dim, target_dim); + CHECK_GE(source_dim, 0); + CHECK_GE(target_dim, 0); + CHECK_LT(source_dim, sharding.TiledDataRank()); + CHECK_LT(target_dim, sharding.TiledDataRank()); + + // There are 3 steps to move and merge the sharding tiles. Given the sharding + // with tile assignment [a, b, c, d, e], source_dim = 1, target_dim = 3, the + // steps are: + // 1. Reshape the tile assignment to [a, b, c, d, 1, e] by inserting a 1 after + // the target_dim. + // 2. Transpose the tile assignment to [a, 1, c, d, b, e] by swapping the + // source_dim and inserted dim of size 1. + // 3. Reshape the tile assignment to [a, 1, c, db, e] by merging the + // target_dim and the swapped source_dim. + + // Step 1. Adding a dummy dim of size 1 after the target_dim. + std::vector ta_dims_1( + sharding.tile_assignment().dimensions().begin(), + sharding.tile_assignment().dimensions().end()); + ta_dims_1.insert(ta_dims_1.begin() + target_dim + 1, 1); + TileAssignment new_tile_assignment = + sharding.tile_assignment().Reshape(ta_dims_1); + + // Step 2. Transpose the tile assignment to swap the source_dim and the + // inserted dim of size 1. + std::vector permutation(new_tile_assignment.num_dimensions()); + absl::c_iota(permutation, 0); + std::swap(permutation[target_dim + 1], + permutation[source_dim + (source_dim < target_dim ? 0 : 1)]); + new_tile_assignment = new_tile_assignment.Transpose(permutation); + + // Step 3. Reshape the tile assignment to merge the target_dim and the swapped + // source_dim. + std::vector ta_dims_2(new_tile_assignment.dimensions().begin(), + new_tile_assignment.dimensions().end()); + ta_dims_2[target_dim] *= ta_dims_2[target_dim + 1]; + ta_dims_2.erase(ta_dims_2.begin() + target_dim + 1); + new_tile_assignment = new_tile_assignment.Reshape(ta_dims_2); + + if (sharding.ReplicateOnLastTileDim()) { + return HloSharding::PartialTile(new_tile_assignment, sharding.metadata()); + } + return HloSharding::Subgroup(new_tile_assignment, sharding.subgroup_types(), + sharding.metadata()); +} + HloSharding TransposeSharding(const HloSharding& sharding, absl::Span dimensions) { if (sharding.IsTileMaximal() || sharding.IsManual()) { diff --git a/xla/hlo/utils/hlo_sharding_util.h b/xla/hlo/utils/hlo_sharding_util.h index 3233fa1624549..051af815aec16 100644 --- a/xla/hlo/utils/hlo_sharding_util.h +++ b/xla/hlo/utils/hlo_sharding_util.h @@ -108,6 +108,13 @@ std::optional GetMostOccurringDevice( std::optional GetDominantDevice( absl::Span computations, double dominant_factor); +// Given a tiled sharding, move the tiles from source_dim and merge it into +// target_dim. For example, given a sharding with tile assignment [a, b, c, d, +// e], source_dim = 1, target_dim = 3, the function will return a sharding with +// tile assignment [a, 1, c, db, e]. +HloSharding MoveAndMergeShardingTiles(const HloSharding& sharding, + int64_t source_dim, int64_t target_dim); + // Returns the HloSharding with the tile dimensions and tile assignment // transposed based on the specified dimension numbers. In case of a tile // maximal sharding returns the original sharding. diff --git a/xla/hlo/utils/hlo_sharding_util_test.cc b/xla/hlo/utils/hlo_sharding_util_test.cc index 75c81215cb691..73bd59ac0d6a4 100644 --- a/xla/hlo/utils/hlo_sharding_util_test.cc +++ b/xla/hlo/utils/hlo_sharding_util_test.cc @@ -112,6 +112,34 @@ TEST(HloShardingUtilTest, MergeShardingIfCompatible8) { HloSharding::Tile(TileAssignment({2, 4}, {2, 2, 2}, {0, 2, 1}))); } +TEST(HloShardingUtilTest, MoveAndMergeShardingTilesPartialTile) { + HloSharding sharding = + HloSharding::PartialTile(TileAssignment({2, 3, 5, 7, 11})); + EXPECT_EQ(MoveAndMergeShardingTiles(sharding, 1, 3), + HloSharding::PartialTile(TileAssignment( + {2, 1, 5, 7 * 3, 11}, {2, 3, 5, 7, 11}, {0, 2, 3, 1, 4}))); + + EXPECT_EQ(MoveAndMergeShardingTiles(sharding, 3, 1), + HloSharding::PartialTile(TileAssignment( + {2, 3 * 7, 5, 1, 11}, {2, 3, 5, 7, 11}, {0, 1, 3, 2, 4}))); +} + +TEST(HloShardingUtilTest, MoveAndMergeShardingTilesSubGroup) { + HloSharding sharding = + HloSharding::Subgroup(TileAssignment({2, 3, 5, 7, 11}), + {OpSharding::MANUAL, OpSharding::REPLICATED}); + EXPECT_EQ( + MoveAndMergeShardingTiles(sharding, 0, 2), + HloSharding::Subgroup(TileAssignment({1, 3, 5 * 2, 7, 11}, + {2, 3, 5, 7, 11}, {1, 2, 0, 3, 4}), + {OpSharding::MANUAL, OpSharding::REPLICATED})); + EXPECT_EQ( + MoveAndMergeShardingTiles(sharding, 2, 0), + HloSharding::Subgroup(TileAssignment({2 * 5, 3, 1, 7, 11}, + {2, 3, 5, 7, 11}, {0, 2, 1, 3, 4}), + {OpSharding::MANUAL, OpSharding::REPLICATED})); +} + TEST(HloShardingUtilTest, TransposeShardingReplicated) { EXPECT_EQ(TransposeSharding(HloSharding::Replicate(), {0, 1, 2}), HloSharding::Replicate());