Simplify the input argument of _construct_global_block_info_list (#70)
Summary:
Pull Request resolved: #70

Instead of passing the entire `buffer_size_ranks` tuple, the `group_source_ranks` tuple is extracted from it and passed as an argument. This simplifies the code and makes it more readable. Furthermore, this might help the consolidation of `_construct_global_block_info_list` and `_construct_local_block_info_list` in the future.

Reviewed By: anana10c

Differential Revision: D67606282

fbshipit-source-id: 7aefea61be746f8ae967d0e21bc32708e52b20c1
tsunghsienlee authored and facebook-github-bot committed Dec 23, 2024
1 parent 618d857 commit 7b14418
Showing 3 changed files with 49 additions and 16 deletions.
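For readers skimming the diff: each distributor now receives only the per-block source ranks instead of the full (buffer size, rank) pairs. Below is a minimal sketch of the call-site pattern; the generator expression is the one used in the diff, while the concrete sizes and ranks are made up for illustration.

# Illustrative sketch only; the values are made up.
# buffer_size_ranks pairs a buffer size with an assigned source rank for each block.
buffer_size_ranks = ((1024, 0), (2048, 1), (512, 0))

# Before: the whole tuple of (size, rank) pairs was passed to
# _construct_global_block_info_list, which unpacked it internally.

# After: only the ranks are extracted at the call site and passed in.
group_source_ranks = tuple(
    group_source_rank for _, group_source_rank in buffer_size_ranks
)
assert group_source_ranks == (0, 1, 0)
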
19 changes: 13 additions & 6 deletions distributed_shampoo/utils/shampoo_ddp_distributor.py
@@ -107,7 +107,9 @@ def __init__(
            )

        global_block_info_list = self._construct_global_block_info_list(
-            buffer_size_ranks
+            group_source_ranks=tuple(
+                group_source_rank for _, group_source_rank in buffer_size_ranks
+            )
        )
        # Initialize selectors and local blocked (masked) parameters.
        self._distributor_selector: tuple[bool, ...] = tuple(
@@ -263,14 +265,19 @@ def _distribute_buffer_sizes(

    @torch.no_grad()
    def _construct_global_block_info_list(
-        self, buffer_size_ranks: tuple[tuple[int, int], ...]
+        self, group_source_ranks: tuple[int, ...]
    ) -> tuple[DDPBlockInfo, ...]:
        """Construct the global block info list.
+
+        This method creates a list of DDPBlockInfo objects, which contain information
+        about each parameter block, including its composable block IDs, a function to
+        allocate zero tensors, a method to retrieve tensors, and the group source rank.
+
        Args:
-            buffer_size_ranks (tuple[tuple[int, int], ...]): A list of tuples containing the buffer size
-                and an assigned rank for each block.
+            group_source_ranks (tuple[int, ...]): A list of assigned ranks for each block.
+
        Returns:
            tuple[DDPBlockInfo, ...]: A tuple of DDPBlockInfo objects for each parameter block.
        """
        return tuple(
            DDPBlockInfo(
@@ -298,9 +305,9 @@ def _construct_global_block_info_list(
                generate_pairwise_indices(self._global_num_blocks_per_param),
                strict=True,
            )
-            for block_index, (_, group_source_rank) in enumerate(
+            for block_index, group_source_rank in enumerate(
                islice(
-                    buffer_size_ranks, buffer_size_ranks_start, buffer_size_ranks_end
+                    group_source_ranks, buffer_size_ranks_start, buffer_size_ranks_end
                )
            )
        )
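To make the simplified iteration above concrete, here is a small standalone sketch. Assumptions: pairwise_indices is a hypothetical stand-in for the repository's generate_pairwise_indices helper, the block counts and ranks are made up, and Python 3.10+ is needed for itertools.pairwise.

from itertools import accumulate, islice, pairwise


def pairwise_indices(num_blocks_per_param):
    # Hypothetical stand-in: yields (start, end) index pairs from cumulative
    # block counts, e.g. (2, 1, 3) -> (0, 2), (2, 3), (3, 6).
    return pairwise(accumulate(num_blocks_per_param, initial=0))


global_num_blocks_per_param = (2, 1, 3)   # made-up block counts per parameter
group_source_ranks = (0, 1, 1, 0, 2, 3)   # one assigned rank per block, flattened

# Mirrors the new loop: slice the flat per-block ranks into per-parameter runs,
# instead of unpacking (size, rank) pairs from buffer_size_ranks.
for param_index, (start, end) in enumerate(pairwise_indices(global_num_blocks_per_param)):
    for block_index, group_source_rank in enumerate(islice(group_source_ranks, start, end)):
        print(param_index, block_index, group_source_rank)
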
23 changes: 18 additions & 5 deletions distributed_shampoo/utils/shampoo_hsdp_distributor.py
@@ -201,7 +201,9 @@ def __init__(
            )

        global_block_info_list = self._construct_global_block_info_list(
-            buffer_size_ranks
+            group_source_ranks=tuple(
+                group_source_rank for _, group_source_rank in buffer_size_ranks
+            )
        )
        # Initialize selectors and local blocked (masked) parameters.
        self._distributor_selector: tuple[bool, ...] = tuple(
@@ -375,9 +377,20 @@ def _construct_composable_block_ids(

    @torch.no_grad()
    def _construct_global_block_info_list(
-        self, buffer_size_ranks: tuple[tuple[int, int], ...]
+        self, group_source_ranks: tuple[int, ...]
    ) -> tuple[DDPBlockInfo, ...]:
-        """Construct global block info list from param_group and num_blocks_within_param."""
+        """Construct the global block info list.
+
+        This method creates a list of DDPBlockInfo objects, which contain information
+        about each parameter block, including its composable block IDs, a function to
+        allocate zero tensors, a method to retrieve tensors, and the group source rank.
+
+        Args:
+            group_source_ranks (tuple[int, ...]): A list of assigned ranks for each block.
+
+        Returns:
+            tuple[DDPBlockInfo, ...]: A tuple of DDPBlockInfo objects for each parameter block.
+        """
        # Note that for HSDP, we want to get the rank within each sharded group for the block id.
        # When using a device mesh, 0 corresponds to the replicated group and 1 corresponds to the sharded group.
        sharded_group_rank = self._hsdp_device_mesh.get_local_rank(1)
@@ -408,9 +421,9 @@ def _construct_global_block_info_list(
                generate_pairwise_indices(self._global_num_blocks_per_param),
                strict=True,
            )
-            for block_index, (_, group_source_rank) in enumerate(
+            for block_index, group_source_rank in enumerate(
                islice(
-                    buffer_size_ranks, buffer_size_ranks_start, buffer_size_ranks_end
+                    group_source_ranks, buffer_size_ranks_start, buffer_size_ranks_end
                )
            )
        )
23 changes: 18 additions & 5 deletions distributed_shampoo/utils/shampoo_hybrid_shard_distributor.py
@@ -187,7 +187,9 @@ def __init__(
            )

        global_block_info_list = self._construct_global_block_info_list(
-            buffer_size_ranks
+            group_source_ranks=tuple(
+                group_source_rank for _, group_source_rank in buffer_size_ranks
+            )
        )

        # Initialize selectors and local blocked (masked) parameters.
@@ -379,9 +381,20 @@ def _construct_composable_block_ids(

    @torch.no_grad()
    def _construct_global_block_info_list(
-        self, buffer_size_ranks: tuple[tuple[int, int], ...]
+        self, group_source_ranks: tuple[int, ...]
    ) -> tuple[DDPBlockInfo, ...]:
-        """Construct global block info list from param_group and num_blocks_within_param."""
+        """Construct the global block info list.
+
+        This method creates a list of DDPBlockInfo objects, which contain information
+        about each parameter block, including its composable block IDs, a function to
+        allocate zero tensors, a method to retrieve tensors, and the group source rank.
+
+        Args:
+            group_source_ranks (tuple[int, ...]): A list of assigned ranks for each block.
+
+        Returns:
+            tuple[DDPBlockInfo, ...]: A tuple of DDPBlockInfo objects for each parameter block.
+        """
        # Call `super()` instead of `self` as a performance optimization.
        # This leads to O(1) instead of O(N) complexity to retrieve the parameters.
        non_empty_params: Iterable[DTensor] = filter(
@@ -419,9 +432,9 @@ def _construct_global_block_info_list(
                generate_pairwise_indices(self._global_num_blocks_per_param),
                strict=True,
            )
-            for block_index, (_, group_source_rank) in enumerate(
+            for block_index, group_source_rank in enumerate(
                islice(
-                    buffer_size_ranks, buffer_size_ranks_start, buffer_size_ranks_end
+                    group_source_ranks, buffer_size_ranks_start, buffer_size_ranks_end
                )
            )
        )
