Refactor Distributors to make them actually do their jobs #68

Closed
50 changes: 19 additions & 31 deletions distributed_shampoo/distributed_shampoo.py
@@ -516,10 +516,9 @@ def _instantiate_shampoo_preconditioner_list(self) -> None:
)

state_lists[SHAMPOO_PRECONDITIONER_LIST] = preconditioner_list_cls(
block_list=state_lists[DISTRIBUTOR].global_blocked_params,
block_list=state_lists[DISTRIBUTOR].local_blocked_params,
state=self.state,
block_info_list=state_lists[DISTRIBUTOR].global_block_info_list,
distributor_selector=state_lists[DISTRIBUTOR].distributor_selector,
block_info_list=state_lists[DISTRIBUTOR].local_block_info_list,
preconditioner_config=group[PRECONDITIONER_CONFIG],
beta2=group[BETAS][1],
epsilon=group[EPSILON],
@@ -537,18 +536,17 @@ def _instantiate_grafting(self) -> None:
state_lists[GRAFTING_PRECONDITIONER_LIST] = None
elif type(group[GRAFTING_CONFIG]) is SGDGraftingConfig:
state_lists[GRAFTING_PRECONDITIONER_LIST] = SGDPreconditionerList(
block_list=state_lists[DISTRIBUTOR].global_blocked_params,
block_list=state_lists[DISTRIBUTOR].local_blocked_params,
)
elif type(group[GRAFTING_CONFIG]) in (
AdaGradGraftingConfig,
RMSpropGraftingConfig,
AdamGraftingConfig,
):
state_lists[GRAFTING_PRECONDITIONER_LIST] = AdagradPreconditionerList(
block_list=state_lists[DISTRIBUTOR].global_blocked_params,
block_list=state_lists[DISTRIBUTOR].local_blocked_params,
state=self.state,
block_info_list=state_lists[DISTRIBUTOR].global_block_info_list,
distributor_selector=state_lists[DISTRIBUTOR].distributor_selector,
block_info_list=state_lists[DISTRIBUTOR].local_block_info_list,
beta2=(
1.0
if type(group[GRAFTING_CONFIG]) is AdaGradGraftingConfig
@@ -565,7 +563,7 @@ def _instantiate_steps(self) -> None:
def _instantiate_steps(self) -> None:
for state_lists in self._per_group_state_lists:
assert (
len(state_lists[DISTRIBUTOR].global_block_info_list) > 0
len(state_lists[DISTRIBUTOR].local_block_info_list) > 0
), "There is no params in your param_group. Please check the instantiation of DistributedShampoo "
'with param_group containing no params. For example, DistributedShampoo(params=[{"params": []}])'
# NOTE: We instantiate a single step tensor on CPU for each group in order
@@ -575,7 +573,7 @@ def _instantiate_steps(self) -> None:

# In order to ensure that the step counter is checkpointed correctly, we store it
# as a tensor (which is replicated across all devices) under the first parameter's state.
block_info = state_lists[DISTRIBUTOR].global_block_info_list[0]
block_info = state_lists[DISTRIBUTOR].local_block_info_list[0]
self.state[block_info.param][STEP] = state_lists[STEP]

@torch.no_grad()
@@ -586,11 +584,11 @@ def _instantiate_momentum(self) -> None:
if group[MOMENTUM] == 0.0:
continue

# Construct global momentum list.
global_momentum_list = []
# Construct local momentum list.
local_momentum_list = []
for block, block_info in zip(
state_lists[DISTRIBUTOR].global_blocked_params,
state_lists[DISTRIBUTOR].global_block_info_list,
state_lists[DISTRIBUTOR].local_blocked_params,
state_lists[DISTRIBUTOR].local_block_info_list,
strict=True,
):
assert (
@@ -608,15 +606,9 @@ def _instantiate_momentum(self) -> None:
dtype=block.dtype,
device=block.device,
)
global_momentum_list.append(
block_info.get_tensor(block_state[MOMENTUM])
)
local_momentum_list.append(block_state[MOMENTUM])

# We compress the momentum list to only the locally-owned parameter states.
state_lists[MOMENTUM_LIST] = compress_list(
global_momentum_list,
state_lists[DISTRIBUTOR].distributor_selector,
)
state_lists[MOMENTUM_LIST] = local_momentum_list
# Here, we set masked momentum list to momentum list because we assume
# all parameters are active.
state_lists[MASKED_MOMENTUM_LIST] = state_lists[MOMENTUM_LIST]
@@ -629,11 +621,11 @@ def _instantiate_filtered_grads(self) -> None:
if group[BETAS][0] == 0.0:
continue

# Construct global filtered gradient list.
global_filtered_grad_list = []
# Construct local filtered gradient list.
local_filtered_grad_list = []
for block, block_info in zip(
state_lists[DISTRIBUTOR].global_blocked_params,
state_lists[DISTRIBUTOR].global_block_info_list,
state_lists[DISTRIBUTOR].local_blocked_params,
state_lists[DISTRIBUTOR].local_block_info_list,
strict=True,
):
assert (
@@ -651,15 +643,11 @@ def _instantiate_filtered_grads(self) -> None:
dtype=block.dtype,
device=block.device,
)
global_filtered_grad_list.append(
local_filtered_grad_list.append(
block_info.get_tensor(block_state[FILTERED_GRAD])
)

# We compress the momentum list to only the locally-owned parameter states.
state_lists[FILTERED_GRAD_LIST] = compress_list(
global_filtered_grad_list,
state_lists[DISTRIBUTOR].distributor_selector,
)
state_lists[FILTERED_GRAD_LIST] = local_filtered_grad_list
# Here, we set masked filtered grad list to filtered grad list because we assume
# all parameters are active.
state_lists[MASKED_FILTERED_GRAD_LIST] = state_lists[FILTERED_GRAD_LIST]
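For readers skimming the diff: the hunks above replace the build-global-then-compress pattern with direct construction over the distributor's local views. A minimal sketch of that difference, assuming `compress_list` simply masks a sequence with a boolean selector (toy data, not code from this PR):

```python
from collections.abc import Sequence
from typing import TypeVar

T = TypeVar("T")


def compress_list(complete_list: Sequence[T], selector: Sequence[bool]) -> tuple[T, ...]:
    # Toy stand-in: keep only entries whose selector flag is True.
    return tuple(
        entry for entry, keep in zip(complete_list, selector, strict=True) if keep
    )


# Old pattern: build optimizer state over every global block, then mask it down
# to the blocks this rank owns.
global_momentum = ["m0", "m1", "m2", "m3"]          # one entry per global block
distributor_selector = [True, False, True, False]   # True where this rank owns the block
old_momentum_list = compress_list(global_momentum, distributor_selector)

# New pattern: the distributor already exposes only local blocks, so state is
# built directly over them and no masking step is needed.
local_momentum = ["m0", "m2"]
new_momentum_list = tuple(local_momentum)

assert old_momentum_list == new_momentum_list == ("m0", "m2")
```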
18 changes: 11 additions & 7 deletions distributed_shampoo/utils/shampoo_ddp_distributor.py
@@ -106,12 +106,13 @@ def __init__(
)
)

self._construct_global_block_info_list(buffer_size_ranks)

global_block_info_list = self._construct_global_block_info_list(
buffer_size_ranks
)
# Initialize selectors and local blocked (masked) parameters.
self._distributor_selector: tuple[bool, ...] = tuple(
block_info.group_source_rank == group_rank
for block_info in self._global_block_info_list
for block_info in global_block_info_list
)
self._local_blocked_params: tuple[Tensor, ...] = compress_list(
self._global_blocked_params, self._distributor_selector
@@ -122,6 +123,9 @@
self._local_grad_selector: tuple[bool, ...] = (True,) * len(
self._local_blocked_params
)
self._local_block_info_list: tuple[DDPBlockInfo, ...] = compress_list(
global_block_info_list, self._distributor_selector
)

self._construct_distributed_buffers(
buffer_size_ranks=buffer_size_ranks,
@@ -257,18 +261,18 @@ def _distribute_buffer_sizes(

return tuple(buffer_size_ranks)

@torch.no_grad()
def _construct_global_block_info_list(
self, buffer_size_ranks: tuple[tuple[int, int], ...]
) -> None:
) -> tuple[DDPBlockInfo, ...]:
"""Construct the global block info list.

Args:
buffer_size_ranks (tuple[tuple[int, int], ...]): A list of tuples containing the buffer size
and an assigned rank for each block.

"""
# Construct global block info list.
self._global_block_info_list: tuple[DDPBlockInfo, ...] = tuple(
return tuple(
DDPBlockInfo(
param=param,
composable_block_ids=self._construct_composable_block_ids(
Expand Down Expand Up @@ -392,7 +396,7 @@ def _construct_distributed_buffers(
self._global_dist_buffer = torch.zeros(
total_buffer_size,
dtype=torch.int8,
device=self._global_block_info_list[0].param.device,
device=self._global_blocked_params[0].device,
)
local_dist_buffers = torch.split(self._global_dist_buffer, max_buffer_size_sum)
splitted_local_dist_buffers = DDPDistributor._split_local_dist_buffers(
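With this change the DDP distributor keeps the global block info list as a local variable, derives the rank selector from each block's `group_source_rank`, and stores only the compressed local view. A toy sketch of that flow (`ToyBlockInfo` is a stand-in, not the real `DDPBlockInfo`):

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ToyBlockInfo:
    param_name: str
    group_source_rank: int  # rank whose buffer owns this block


def build_local_block_info(
    global_block_info: tuple[ToyBlockInfo, ...], group_rank: int
) -> tuple[ToyBlockInfo, ...]:
    # Selector is True exactly for the blocks assigned to this rank.
    selector = tuple(
        info.group_source_rank == group_rank for info in global_block_info
    )
    return tuple(
        info for info, keep in zip(global_block_info, selector, strict=True) if keep
    )


global_infos = (
    ToyBlockInfo("w0.block0", group_source_rank=0),
    ToyBlockInfo("w0.block1", group_source_rank=1),
    ToyBlockInfo("w1.block0", group_source_rank=0),
)
assert [b.param_name for b in build_local_block_info(global_infos, group_rank=0)] == [
    "w0.block0",
    "w1.block0",
]
```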
33 changes: 13 additions & 20 deletions distributed_shampoo/utils/shampoo_distributor.py
@@ -63,8 +63,8 @@ def __init__(self, param_group: dict[str, Any]) -> None:
self._local_blocked_params: tuple[Tensor, ...]
# Local masked blocked params are the parameters masked by the distributor selector AND the local grad selector.
self._local_masked_blocked_params: tuple[Tensor, ...]
# Global block info list contains information about each global block.
self._global_block_info_list: tuple[BlockInfo, ...] | tuple[DDPBlockInfo, ...]
# Local block info list contains information about each block masked by the distributor selector.
self._local_block_info_list: tuple[BlockInfo, ...] | tuple[DDPBlockInfo, ...]

@abstractmethod
@torch.no_grad()
@@ -73,14 +73,6 @@ def update_params(
masked_blocked_search_directions: tuple[Tensor, ...],
) -> None: ...

@property
def global_blocked_params(self) -> tuple[Tensor, ...]:
return self._global_blocked_params

@property
def distributor_selector(self) -> tuple[bool, ...]:
return self._distributor_selector

@property
def local_grad_selector(self) -> tuple[bool, ...]:
return self._local_grad_selector
@@ -94,8 +86,8 @@ def local_masked_blocked_params(self) -> tuple[Tensor, ...]:
return self._local_masked_blocked_params

@property
def global_block_info_list(self) -> tuple[BlockInfo, ...]:
return self._global_block_info_list
def local_block_info_list(self) -> tuple[BlockInfo, ...]:
return self._local_block_info_list

def _construct_composable_block_ids(
self,
@@ -258,18 +250,18 @@ def __init__(
param_group: dict[str, Any],
) -> None:
super().__init__(param_group)
self._construct_global_block_info_list()

# Initialize selectors and local blocked (masked) parameters.
self._local_grad_selector: tuple[bool, ...] = (True,) * len(
self._global_blocked_params
)
self._distributor_selector: tuple[bool, ...] = self._local_grad_selector
self._local_blocked_params: tuple[Tensor, ...] = self._global_blocked_params
self._local_masked_blocked_params: tuple[Tensor, ...] = (
self._global_blocked_params
self._local_blocked_params
)
self._local_blocked_params: tuple[Tensor, ...] = (
self._local_masked_blocked_params
self._local_block_info_list: tuple[BlockInfo, ...] = (
self._construct_local_block_info_list()
)

@torch.no_grad()
Expand All @@ -288,11 +280,12 @@ def update_params(
masked_blocked_search_directions,
)

def _construct_global_block_info_list(
@torch.no_grad()
def _construct_local_block_info_list(
self,
) -> None:
"""Construct global block info list from param_group and num_blocks_within_param."""
self._global_block_info_list = tuple(
) -> tuple[BlockInfo, ...]:
"""Construct local block info list from param_group and num_blocks_within_param."""
return tuple(
BlockInfo(
param=param,
composable_block_ids=self._construct_composable_block_ids(
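The base `Distributor` refactor also changes the construction style: `_construct_local_block_info_list()` now returns a tuple that `__init__` assigns, instead of mutating `self` as a side effect. A simplified sketch of the resulting shape (names and fields abbreviated, not the library's actual ones):

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ToyBlockInfo:
    param_index: int
    block_index: int


class ToyDistributor:
    def __init__(self, num_blocks_per_param: tuple[int, ...]) -> None:
        self._num_blocks_per_param = num_blocks_per_param
        # The local view is the only block-info list kept on the instance.
        self._local_block_info_list: tuple[ToyBlockInfo, ...] = (
            self._construct_local_block_info_list()
        )

    def _construct_local_block_info_list(self) -> tuple[ToyBlockInfo, ...]:
        # Builder returns a value; __init__ decides where it is stored.
        return tuple(
            ToyBlockInfo(param_index=i, block_index=j)
            for i, n in enumerate(self._num_blocks_per_param)
            for j in range(n)
        )

    @property
    def local_block_info_list(self) -> tuple[ToyBlockInfo, ...]:
        return self._local_block_info_list


assert len(ToyDistributor((2, 3)).local_block_info_list) == 5
```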
13 changes: 8 additions & 5 deletions distributed_shampoo/utils/shampoo_fsdp_distributor.py
@@ -55,7 +55,6 @@ def __init__(
self._global_num_blocks_per_split_param: tuple[int, ...] = ()

super().__init__(param_group)
self._construct_global_block_info_list()

# Initialize selectors and local blocked (masked) parameters.
self._local_grad_selector: tuple[bool, ...] = (True,) * len(
@@ -66,6 +65,9 @@ def __init__(
self._global_blocked_params
)
self._local_blocked_params: tuple[Tensor, ...] = self._global_blocked_params
self._local_block_info_list: tuple[BlockInfo, ...] = (
self._construct_local_block_info_list()
)

@torch.no_grad()
def update_params(
@@ -102,12 +104,13 @@ def _construct_composable_block_ids(
"""
return (param_index, f"rank_{rank}-block_{block_index}")

def _construct_global_block_info_list(
@torch.no_grad()
def _construct_local_block_info_list(
self,
) -> None:
"""Construct global block info list from param_group and num_blocks_within_param."""
) -> tuple[BlockInfo, ...]:
"""Construct local block info list from param_group and num_blocks_within_param."""
rank = dist.get_rank()
self._global_block_info_list = tuple(
return tuple(
BlockInfo(
param=param,
composable_block_ids=self._construct_composable_block_ids(
8 changes: 4 additions & 4 deletions distributed_shampoo/utils/shampoo_fully_shard_distributor.py
@@ -65,10 +65,10 @@ def _construct_composable_block_ids(
return (param_index, f"rank_{rank}-block_{block_index}")

@torch.no_grad()
def _construct_global_block_info_list(
def _construct_local_block_info_list(
self,
) -> None:
"""Construct global block info list from param_group and num_blocks_within_param."""
) -> tuple[BlockInfo, ...]:
"""Construct local block info list from param_group and num_blocks_within_param."""
rank = dist.get_rank()

# Call `super()` instead of `self` as a performance optimization.
@@ -77,7 +77,7 @@ def _construct_global_block_info_list(
lambda p: p.to_local().numel() > 0, # type: ignore[arg-type]
super()._get_params_or_grads(),
)
self._global_block_info_list = tuple(
return tuple(
BlockInfo(
param=param,
composable_block_ids=self._construct_composable_block_ids(
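For the fully-shard (per-parameter FSDP) distributor, block info is built only for parameters whose local shard is non-empty on this rank. A toy illustration of that filter, using a stub in place of a real DTensor parameter (on a DTensor, `to_local()` returns the rank-local shard):

```python
import torch


class FakeShardedParam:
    """Stub mimicking a DTensor parameter: to_local() returns the rank-local shard."""

    def __init__(self, local_shape: tuple[int, ...]) -> None:
        self._local = torch.zeros(local_shape)

    def to_local(self) -> torch.Tensor:
        return self._local


params = [FakeShardedParam((4, 4)), FakeShardedParam((0,)), FakeShardedParam((2, 3))]
# Parameters with an empty local shard on this rank contribute no local blocks.
non_empty = [p for p in params if p.to_local().numel() > 0]
assert len(non_empty) == 2
```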
17 changes: 11 additions & 6 deletions distributed_shampoo/utils/shampoo_hsdp_distributor.py
@@ -200,12 +200,13 @@ def __init__(
)
)

self._construct_global_block_info_list(buffer_size_ranks)

global_block_info_list = self._construct_global_block_info_list(
buffer_size_ranks
)
# Initialize selectors and local blocked (masked) parameters.
self._distributor_selector: tuple[bool, ...] = tuple(
block_info.group_source_rank == comms_group_rank
for block_info in self._global_block_info_list
for block_info in global_block_info_list
)
self._local_blocked_params: tuple[Tensor, ...] = compress_list(
self._global_blocked_params, self._distributor_selector
@@ -216,6 +217,9 @@ def __init__(
self._local_grad_selector: tuple[bool, ...] = (True,) * len(
self._local_blocked_params
)
self._local_block_info_list: tuple[DDPBlockInfo, ...] = compress_list(
global_block_info_list, self._distributor_selector
)

self._construct_distributed_buffers(
buffer_size_ranks=buffer_size_ranks,
@@ -369,14 +373,15 @@ def _construct_composable_block_ids(
"""
return (param_index, f"rank_{rank}-block_{block_index}")

@torch.no_grad()
def _construct_global_block_info_list(
self, buffer_size_ranks: tuple[tuple[int, int], ...]
) -> None:
) -> tuple[DDPBlockInfo, ...]:
"""Construct global block info list from param_group and num_blocks_within_param."""
# Note that for HSDP, we want to get the rank within each sharded group for the block id.
# When using a device mesh, 0 corresponds to the replicated group and 1 corresponds to the sharded group.
sharded_group_rank = self._hsdp_device_mesh.get_local_rank(1)
self._global_block_info_list: tuple[DDPBlockInfo, ...] = tuple(
return tuple(
DDPBlockInfo(
param=param,
composable_block_ids=self._construct_composable_block_ids(
@@ -575,7 +580,7 @@ def _construct_distributed_buffers(
self._global_dist_buffer = torch.zeros(
total_buffer_size,
dtype=torch.int8,
device=self._global_block_info_list[0].param.device,
device=self._global_blocked_params[0].device,
)
local_dist_buffers = torch.split(self._global_dist_buffer, max_buffer_size_sum)
splitted_local_dist_buffers = HSDPDistributor._split_local_dist_buffers(
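In both the DDP and HSDP distributors, the flat communication buffer's device is now taken from the first blocked parameter rather than from block info. A small sketch of that allocation (illustrative sizes and helper name, not the real HSDP bookkeeping):

```python
import torch


def allocate_dist_buffer(
    global_blocked_params: tuple[torch.Tensor, ...], total_buffer_size: int
) -> torch.Tensor:
    # The flat int8 communication buffer only needs to live on the same device
    # as the parameters it will gather/scatter, so the first blocked parameter
    # is a sufficient source for the device.
    return torch.zeros(
        total_buffer_size,
        dtype=torch.int8,
        device=global_blocked_params[0].device,
    )


params = (torch.zeros(4, 4), torch.zeros(8))
buf = allocate_dist_buffer(params, total_buffer_size=256)
assert buf.dtype == torch.int8 and buf.device == params[0].device
```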