Refactor Distributors to make them actually do their job (#68)
Summary:

In the current Shampoo design, there are three major components: `DistributedShampoo`, which describes the high-level algorithm flow; `Distributor`, which manages the distribution of parameters for different computing paradigms (e.g., DDP, FSDP, HSDP, etc.); and `PreconditionerList`, which contains the detailed algorithm implementation.

The `Distributor` manages the distribution of parameters and then hands them to `DistributedShampoo`, which invokes the corresponding algorithms implemented by each `PreconditionerList`.
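
A minimal sketch of this flow, with hypothetical, heavily simplified signatures (the real constructors take many more arguments and handle state, grafting, and distribution details):

```python
# Simplified, hypothetical sketch of the three-component flow.
from typing import Any

import torch
from torch import Tensor


class Distributor:
    """Blocks parameters and decides which blocks this rank owns."""

    def __init__(self, param_group: dict[str, Any]) -> None:
        # In the real implementation, parameters are split into blocks and
        # assigned to ranks depending on the paradigm (DDP, FSDP, HSDP, ...).
        self._local_blocked_params: tuple[Tensor, ...] = tuple(param_group["params"])

    @property
    def local_blocked_params(self) -> tuple[Tensor, ...]:
        return self._local_blocked_params


class PreconditionerList:
    """Holds the per-block preconditioner math for the blocks it is given."""

    def __init__(self, block_list: tuple[Tensor, ...]) -> None:
        self._block_list = block_list


class DistributedShampoo:
    """High-level flow: wire the Distributor's local blocks into the preconditioners."""

    def __init__(self, param_group: dict[str, Any]) -> None:
        distributor = Distributor(param_group)
        self._preconditioner_list = PreconditionerList(
            block_list=distributor.local_blocked_params
        )


# Usage with a toy parameter group:
_ = DistributedShampoo({"params": [torch.nn.Parameter(torch.zeros(4, 4))]})
```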

Since the `Distributor` handles parameter distribution, downstream classes (i.e., `DistributedShampoo` and `PreconditionerList`) should ideally not need to use `Distributor.distributor_selector` to compress parameters themselves. However, in the current implementation, this pattern appears in both [`DistributedShampoo`](https://www.internalfb.com/code/fbsource/[20bcdb4983a16c2cd7f00995754851159da91ed3]/fbcode/hpc/optimizers/distributed_shampoo/dev/distributed_shampoo.py?lines=614-617) and [`PreconditionerList`](https://www.internalfb.com/code/fbsource/[20bcdb4983a16c2cd7f00995754851159da91ed3]/fbcode/hpc/optimizers/distributed_shampoo/dev/utils/shampoo_preconditioner_list.py?lines=145-146).
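
The pattern being removed looks roughly like the following. This is a hedged reconstruction based on the removed lines in the diff below; `compress_list` is assumed here to simply mask a tuple with a boolean selector, and the toy tensors stand in for real blocked parameter state:

```python
from itertools import compress
from typing import TypeVar

import torch

_T = TypeVar("_T")


def compress_list(
    complete_list: tuple[_T, ...], selector: tuple[bool, ...]
) -> tuple[_T, ...]:
    """Keep only the entries whose selector flag is True (assumed behavior)."""
    return tuple(compress(complete_list, selector))


# Before this diff: downstream code received *global* lists and had to re-apply
# the distributor selector itself, e.g. when building the momentum list.
global_momentum_list = (torch.zeros(3), torch.zeros(3), torch.zeros(3))
distributor_selector = (True, False, True)  # which blocks this rank owns
local_momentum_list = compress_list(global_momentum_list, distributor_selector)
assert len(local_momentum_list) == 2

# After this diff: the Distributor exposes already-local lists, so neither
# DistributedShampoo nor PreconditionerList needs this compression step.
```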

The main changes are as follows; a sketch of the resulting `Distributor` surface appears after the list:
1. `Distributor` now provides access only to the local versions of `params` and `block_info`, in line with its responsibilities.
2. `Distributor` no longer needs to store `global_block_info_list`, the global version of `block_info`; instead, it stores `local_block_info_list`.
3. `PreconditionerList` no longer requires `distributor_selector` as an input argument because `Distributor` already performs that selection.
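
Taken together, the base class ends up exposing only rank-local views. The following is a sketch assembled from the diff below, not the full class; the name `DistributorInterface` and the exact property set are assumptions for illustration:

```python
from abc import ABC, abstractmethod
from typing import Any

import torch
from torch import Tensor


class DistributorInterface(ABC):
    """Rank-local surface of a Distributor after this refactor (sketch only)."""

    # Concrete subclasses (DDP / FSDP / HSDP / default) populate these in __init__.
    _local_blocked_params: tuple[Tensor, ...]
    _local_masked_blocked_params: tuple[Tensor, ...]
    _local_grad_selector: tuple[bool, ...]
    _local_block_info_list: tuple[Any, ...]  # BlockInfo for locally-owned blocks

    @abstractmethod
    @torch.no_grad()
    def update_params(
        self,
        masked_blocked_search_directions: tuple[Tensor, ...],
    ) -> None: ...

    @property
    def local_grad_selector(self) -> tuple[bool, ...]:
        return self._local_grad_selector

    @property
    def local_blocked_params(self) -> tuple[Tensor, ...]:
        return self._local_blocked_params

    @property
    def local_masked_blocked_params(self) -> tuple[Tensor, ...]:
        return self._local_masked_blocked_params

    @property
    def local_block_info_list(self) -> tuple[Any, ...]:
        return self._local_block_info_list

    # Note: global_blocked_params and distributor_selector are no longer part
    # of the public surface; any global bookkeeping stays internal to subclasses.
```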

Reviewed By: anana10c, chuanhaozhuge

Differential Revision: D64708506
tsunghsienlee authored and facebook-github-bot committed Dec 23, 2024
1 parent 0921f45 commit 57d5dc9
Showing 10 changed files with 88 additions and 140 deletions.
50 changes: 19 additions & 31 deletions distributed_shampoo/distributed_shampoo.py
@@ -516,10 +516,9 @@ def _instantiate_shampoo_preconditioner_list(self) -> None:
)

state_lists[SHAMPOO_PRECONDITIONER_LIST] = preconditioner_list_cls(
block_list=state_lists[DISTRIBUTOR].global_blocked_params,
block_list=state_lists[DISTRIBUTOR].local_blocked_params,
state=self.state,
block_info_list=state_lists[DISTRIBUTOR].global_block_info_list,
distributor_selector=state_lists[DISTRIBUTOR].distributor_selector,
block_info_list=state_lists[DISTRIBUTOR].local_block_info_list,
preconditioner_config=group[PRECONDITIONER_CONFIG],
beta2=group[BETAS][1],
epsilon=group[EPSILON],
@@ -537,18 +536,17 @@ def _instantiate_grafting(self) -> None:
state_lists[GRAFTING_PRECONDITIONER_LIST] = None
elif type(group[GRAFTING_CONFIG]) is SGDGraftingConfig:
state_lists[GRAFTING_PRECONDITIONER_LIST] = SGDPreconditionerList(
block_list=state_lists[DISTRIBUTOR].global_blocked_params,
block_list=state_lists[DISTRIBUTOR].local_blocked_params,
)
elif type(group[GRAFTING_CONFIG]) in (
AdaGradGraftingConfig,
RMSpropGraftingConfig,
AdamGraftingConfig,
):
state_lists[GRAFTING_PRECONDITIONER_LIST] = AdagradPreconditionerList(
block_list=state_lists[DISTRIBUTOR].global_blocked_params,
block_list=state_lists[DISTRIBUTOR].local_blocked_params,
state=self.state,
block_info_list=state_lists[DISTRIBUTOR].global_block_info_list,
distributor_selector=state_lists[DISTRIBUTOR].distributor_selector,
block_info_list=state_lists[DISTRIBUTOR].local_block_info_list,
beta2=(
1.0
if type(group[GRAFTING_CONFIG]) is AdaGradGraftingConfig
@@ -565,7 +563,7 @@ def _instantiate_grafting(self) -> None:
def _instantiate_steps(self) -> None:
for state_lists in self._per_group_state_lists:
assert (
len(state_lists[DISTRIBUTOR].global_block_info_list) > 0
len(state_lists[DISTRIBUTOR].local_block_info_list) > 0
), "There is no params in your param_group. Please check the instantiation of DistributedShampoo "
'with param_group containing no params. For example, DistributedShampoo(params=[{"params": []}])'
# NOTE: We instantiate a single step tensor on CPU for each group in order
@@ -575,7 +573,7 @@ def _instantiate_steps(self) -> None:

# In order to ensure that the step counter is checkpointed correctly, we store it
# as a tensor (which is replicated across all devices) under the first parameter's state.
block_info = state_lists[DISTRIBUTOR].global_block_info_list[0]
block_info = state_lists[DISTRIBUTOR].local_block_info_list[0]
self.state[block_info.param][STEP] = state_lists[STEP]

@torch.no_grad()
@@ -586,11 +584,11 @@ def _instantiate_momentum(self) -> None:
if group[MOMENTUM] == 0.0:
continue

# Construct global momentum list.
global_momentum_list = []
# Construct local momentum list.
local_momentum_list = []
for block, block_info in zip(
state_lists[DISTRIBUTOR].global_blocked_params,
state_lists[DISTRIBUTOR].global_block_info_list,
state_lists[DISTRIBUTOR].local_blocked_params,
state_lists[DISTRIBUTOR].local_block_info_list,
strict=True,
):
assert (
@@ -608,15 +606,9 @@ def _instantiate_momentum(self) -> None:
dtype=block.dtype,
device=block.device,
)
global_momentum_list.append(
block_info.get_tensor(block_state[MOMENTUM])
)
local_momentum_list.append(block_state[MOMENTUM])

# We compress the momentum list to only the locally-owned parameter states.
state_lists[MOMENTUM_LIST] = compress_list(
global_momentum_list,
state_lists[DISTRIBUTOR].distributor_selector,
)
state_lists[MOMENTUM_LIST] = local_momentum_list
# Here, we set masked momentum list to momentum list because we assume
# all parameters are active.
state_lists[MASKED_MOMENTUM_LIST] = state_lists[MOMENTUM_LIST]
@@ -629,11 +621,11 @@ def _instantiate_filtered_grads(self) -> None:
if group[BETAS][0] == 0.0:
continue

# Construct global filtered gradient list.
global_filtered_grad_list = []
# Construct local filtered gradient list.
local_filtered_grad_list = []
for block, block_info in zip(
state_lists[DISTRIBUTOR].global_blocked_params,
state_lists[DISTRIBUTOR].global_block_info_list,
state_lists[DISTRIBUTOR].local_blocked_params,
state_lists[DISTRIBUTOR].local_block_info_list,
strict=True,
):
assert (
@@ -651,15 +643,11 @@ def _instantiate_filtered_grads(self) -> None:
dtype=block.dtype,
device=block.device,
)
global_filtered_grad_list.append(
local_filtered_grad_list.append(
block_info.get_tensor(block_state[FILTERED_GRAD])
)

# We compress the momentum list to only the locally-owned parameter states.
state_lists[FILTERED_GRAD_LIST] = compress_list(
global_filtered_grad_list,
state_lists[DISTRIBUTOR].distributor_selector,
)
state_lists[FILTERED_GRAD_LIST] = local_filtered_grad_list
# Here, we set masked filtered grad list to filtered grad list because we assume
# all parameters are active.
state_lists[MASKED_FILTERED_GRAD_LIST] = state_lists[FILTERED_GRAD_LIST]
18 changes: 11 additions & 7 deletions distributed_shampoo/utils/shampoo_ddp_distributor.py
@@ -106,12 +106,13 @@ def __init__(
)
)

self._construct_global_block_info_list(buffer_size_ranks)

global_block_info_list = self._construct_global_block_info_list(
buffer_size_ranks
)
# Initialize selectors and local blocked (masked) parameters.
self._distributor_selector: tuple[bool, ...] = tuple(
block_info.group_source_rank == group_rank
for block_info in self._global_block_info_list
for block_info in global_block_info_list
)
self._local_blocked_params: tuple[Tensor, ...] = compress_list(
self._global_blocked_params, self._distributor_selector
@@ -122,6 +123,9 @@ def __init__(
self._local_grad_selector: tuple[bool, ...] = (True,) * len(
self._local_blocked_params
)
self._local_block_info_list: tuple[DDPBlockInfo, ...] = compress_list(
global_block_info_list, self._distributor_selector
)

self._construct_distributed_buffers(
buffer_size_ranks=buffer_size_ranks,
@@ -257,18 +261,18 @@ def _distribute_buffer_sizes(

return tuple(buffer_size_ranks)

@torch.no_grad()
def _construct_global_block_info_list(
self, buffer_size_ranks: tuple[tuple[int, int], ...]
) -> None:
) -> tuple[DDPBlockInfo, ...]:
"""Construct the global block info list.
Args:
buffer_size_ranks (tuple[tuple[int, int], ...]): A list of tuples containing the buffer size
and an assigned rank for each block.
"""
# Construct global block info list.
self._global_block_info_list: tuple[DDPBlockInfo, ...] = tuple(
return tuple(
DDPBlockInfo(
param=param,
composable_block_ids=self._construct_composable_block_ids(
@@ -392,7 +396,7 @@ def _construct_distributed_buffers(
self._global_dist_buffer = torch.zeros(
total_buffer_size,
dtype=torch.int8,
device=self._global_block_info_list[0].param.device,
device=self._global_blocked_params[0].device,
)
local_dist_buffers = torch.split(self._global_dist_buffer, max_buffer_size_sum)
splitted_local_dist_buffers = DDPDistributor._split_local_dist_buffers(
33 changes: 13 additions & 20 deletions distributed_shampoo/utils/shampoo_distributor.py
@@ -63,8 +63,8 @@ def __init__(self, param_group: dict[str, Any]) -> None:
self._local_blocked_params: tuple[Tensor, ...]
# Local masked blocked params are the parameters masked by the distributor selector AND the local grad selector.
self._local_masked_blocked_params: tuple[Tensor, ...]
# Global block info list contains information about each global block.
self._global_block_info_list: tuple[BlockInfo, ...] | tuple[DDPBlockInfo, ...]
# Local block info list contains information about each block masked by the distributor selector.
self._local_block_info_list: tuple[BlockInfo, ...] | tuple[DDPBlockInfo, ...]

@abstractmethod
@torch.no_grad()
@@ -73,14 +73,6 @@ def update_params(
masked_blocked_search_directions: tuple[Tensor, ...],
) -> None: ...

@property
def global_blocked_params(self) -> tuple[Tensor, ...]:
return self._global_blocked_params

@property
def distributor_selector(self) -> tuple[bool, ...]:
return self._distributor_selector

@property
def local_grad_selector(self) -> tuple[bool, ...]:
return self._local_grad_selector
@@ -94,8 +86,8 @@ def local_masked_blocked_params(self) -> tuple[Tensor, ...]:
return self._local_masked_blocked_params

@property
def global_block_info_list(self) -> tuple[BlockInfo, ...]:
return self._global_block_info_list
def local_block_info_list(self) -> tuple[BlockInfo, ...]:
return self._local_block_info_list

def _construct_composable_block_ids(
self,
@@ -258,18 +250,18 @@ def __init__(
param_group: dict[str, Any],
) -> None:
super().__init__(param_group)
self._construct_global_block_info_list()

# Initialize selectors and local blocked (masked) parameters.
self._local_grad_selector: tuple[bool, ...] = (True,) * len(
self._global_blocked_params
)
self._distributor_selector: tuple[bool, ...] = self._local_grad_selector
self._local_blocked_params: tuple[Tensor, ...] = self._global_blocked_params
self._local_masked_blocked_params: tuple[Tensor, ...] = (
self._global_blocked_params
self._local_blocked_params
)
self._local_blocked_params: tuple[Tensor, ...] = (
self._local_masked_blocked_params
self._local_block_info_list: tuple[BlockInfo, ...] = (
self._construct_local_block_info_list()
)

@torch.no_grad()
@@ -288,11 +280,12 @@ def update_params(
masked_blocked_search_directions,
)

def _construct_global_block_info_list(
@torch.no_grad()
def _construct_local_block_info_list(
self,
) -> None:
"""Construct global block info list from param_group and num_blocks_within_param."""
self._global_block_info_list = tuple(
) -> tuple[BlockInfo, ...]:
"""Construct local block info list from param_group and num_blocks_within_param."""
return tuple(
BlockInfo(
param=param,
composable_block_ids=self._construct_composable_block_ids(
13 changes: 8 additions & 5 deletions distributed_shampoo/utils/shampoo_fsdp_distributor.py
@@ -55,7 +55,6 @@ def __init__(
self._global_num_blocks_per_split_param: tuple[int, ...] = ()

super().__init__(param_group)
self._construct_global_block_info_list()

# Initialize selectors and local blocked (masked) parameters.
self._local_grad_selector: tuple[bool, ...] = (True,) * len(
@@ -66,6 +65,9 @@ def __init__(
self._global_blocked_params
)
self._local_blocked_params: tuple[Tensor, ...] = self._global_blocked_params
self._local_block_info_list: tuple[BlockInfo, ...] = (
self._construct_local_block_info_list()
)

@torch.no_grad()
def update_params(
@@ -102,12 +104,13 @@ def _construct_composable_block_ids(
"""
return (param_index, f"rank_{rank}-block_{block_index}")

def _construct_global_block_info_list(
@torch.no_grad()
def _construct_local_block_info_list(
self,
) -> None:
"""Construct global block info list from param_group and num_blocks_within_param."""
) -> tuple[BlockInfo, ...]:
"""Construct local block info list from param_group and num_blocks_within_param."""
rank = dist.get_rank()
self._global_block_info_list = tuple(
return tuple(
BlockInfo(
param=param,
composable_block_ids=self._construct_composable_block_ids(
8 changes: 4 additions & 4 deletions distributed_shampoo/utils/shampoo_fully_shard_distributor.py
@@ -65,10 +65,10 @@ def _construct_composable_block_ids(
return (param_index, f"rank_{rank}-block_{block_index}")

@torch.no_grad()
def _construct_global_block_info_list(
def _construct_local_block_info_list(
self,
) -> None:
"""Construct global block info list from param_group and num_blocks_within_param."""
) -> tuple[BlockInfo, ...]:
"""Construct local block info list from param_group and num_blocks_within_param."""
rank = dist.get_rank()

# Call `super()` instead of `self` as a performance optimization.
@@ -77,7 +77,7 @@ def _construct_global_block_info_list(
lambda p: p.to_local().numel() > 0, # type: ignore[arg-type]
super()._get_params_or_grads(),
)
self._global_block_info_list = tuple(
return tuple(
BlockInfo(
param=param,
composable_block_ids=self._construct_composable_block_ids(
17 changes: 11 additions & 6 deletions distributed_shampoo/utils/shampoo_hsdp_distributor.py
@@ -200,12 +200,13 @@ def __init__(
)
)

self._construct_global_block_info_list(buffer_size_ranks)

global_block_info_list = self._construct_global_block_info_list(
buffer_size_ranks
)
# Initialize selectors and local blocked (masked) parameters.
self._distributor_selector: tuple[bool, ...] = tuple(
block_info.group_source_rank == comms_group_rank
for block_info in self._global_block_info_list
for block_info in global_block_info_list
)
self._local_blocked_params: tuple[Tensor, ...] = compress_list(
self._global_blocked_params, self._distributor_selector
@@ -216,6 +217,9 @@ def __init__(
self._local_grad_selector: tuple[bool, ...] = (True,) * len(
self._local_blocked_params
)
self._local_block_info_list: tuple[DDPBlockInfo, ...] = compress_list(
global_block_info_list, self._distributor_selector
)

self._construct_distributed_buffers(
buffer_size_ranks=buffer_size_ranks,
@@ -369,14 +373,15 @@ def _construct_composable_block_ids(
"""
return (param_index, f"rank_{rank}-block_{block_index}")

@torch.no_grad()
def _construct_global_block_info_list(
self, buffer_size_ranks: tuple[tuple[int, int], ...]
) -> None:
) -> tuple[DDPBlockInfo, ...]:
"""Construct global block info list from param_group and num_blocks_within_param."""
# Note that for HSDP, we want to get the rank within each sharded group for the block id.
# When using a device mesh, 0 corresponds to the replicated group and 1 corresponds to the sharded group.
sharded_group_rank = self._hsdp_device_mesh.get_local_rank(1)
self._global_block_info_list: tuple[DDPBlockInfo, ...] = tuple(
return tuple(
DDPBlockInfo(
param=param,
composable_block_ids=self._construct_composable_block_ids(
@@ -575,7 +580,7 @@ def _construct_distributed_buffers(
self._global_dist_buffer = torch.zeros(
total_buffer_size,
dtype=torch.int8,
device=self._global_block_info_list[0].param.device,
device=self._global_blocked_params[0].device,
)
local_dist_buffers = torch.split(self._global_dist_buffer, max_buffer_size_sum)
splitted_local_dist_buffers = HSDPDistributor._split_local_dist_buffers(