From 57d5dc9eeb0271659463c57125a6259c3b226a90 Mon Sep 17 00:00:00 2001 From: Tsung-Hsien Lee Date: Mon, 23 Dec 2024 11:42:35 -0800 Subject: [PATCH] Refactor Distributors to make it acutally do its jobs (#68) Summary: In the current Shampoo design, there are three major components: `DistributedShampoo`, which describes the high-level algorithm flow; `Distributor`, which manages the distribution of parameters for different computing paradigms (e.g., DDP, FSDP, HSDP, etc.); and `PreconditionerList`, which contains the detailed algorithm implementation. The `Distributor` manages the distribution of parameters and then sends them to `DistributedShampoo` which invokes the corresponding algorithms implemented by each `PreconditionerList`. Since the `Distributor` handles parameter distribution, ideally, downstream classes (i.e., `DistributedShampoo` and `PreconditionerList`) should not need to use `Distributor.distributor_selector` to compress parameters for their use. However, in the current implementation, such usages appear in both [`DistributedShampoo`](https://www.internalfb.com/code/fbsource/[20bcdb4983a16c2cd7f00995754851159da91ed3]/fbcode/hpc/optimizers/distributed_shampoo/dev/distributed_shampoo.py?lines=614-617) and [`PreconditionerList`](https://www.internalfb.com/code/fbsource/[20bcdb4983a16c2cd7f00995754851159da91ed3]/fbcode/hpc/optimizers/distributed_shampoo/dev/utils/shampoo_preconditioner_list.py?lines=145-146). The main changes are listed as follows: 1. `Distributor` now only provides access to the local version of `params` and `block_info` to align with its responsibilities. 2. `Distributor` no longer needs to store `global_block_info_list`, the global version of `block_info`; instead, it stores `local_block_info_list`. 3. `PreconditionerList` no longer requires `distributor_selector` as an input argument because `Distributor` already performs this task. Reviewed By: anana10c, chuanhaozhuge Differential Revision: D64708506 --- distributed_shampoo/distributed_shampoo.py | 50 +++++++------------ .../utils/shampoo_ddp_distributor.py | 18 ++++--- .../utils/shampoo_distributor.py | 33 +++++------- .../utils/shampoo_fsdp_distributor.py | 13 +++-- .../utils/shampoo_fully_shard_distributor.py | 8 +-- .../utils/shampoo_hsdp_distributor.py | 17 ++++--- .../utils/shampoo_hybrid_shard_distributor.py | 15 ++++-- .../utils/shampoo_preconditioner_list.py | 29 +++-------- .../utils/tests/shampoo_distributor_test.py | 27 ++-------- .../tests/shampoo_preconditioner_list_test.py | 18 +------ 10 files changed, 88 insertions(+), 140 deletions(-) diff --git a/distributed_shampoo/distributed_shampoo.py b/distributed_shampoo/distributed_shampoo.py index b91fc1c..1089f64 100644 --- a/distributed_shampoo/distributed_shampoo.py +++ b/distributed_shampoo/distributed_shampoo.py @@ -516,10 +516,9 @@ def _instantiate_shampoo_preconditioner_list(self) -> None: ) state_lists[SHAMPOO_PRECONDITIONER_LIST] = preconditioner_list_cls( - block_list=state_lists[DISTRIBUTOR].global_blocked_params, + block_list=state_lists[DISTRIBUTOR].local_blocked_params, state=self.state, - block_info_list=state_lists[DISTRIBUTOR].global_block_info_list, - distributor_selector=state_lists[DISTRIBUTOR].distributor_selector, + block_info_list=state_lists[DISTRIBUTOR].local_block_info_list, preconditioner_config=group[PRECONDITIONER_CONFIG], beta2=group[BETAS][1], epsilon=group[EPSILON], @@ -537,7 +536,7 @@ def _instantiate_grafting(self) -> None: state_lists[GRAFTING_PRECONDITIONER_LIST] = None elif type(group[GRAFTING_CONFIG]) is SGDGraftingConfig: state_lists[GRAFTING_PRECONDITIONER_LIST] = SGDPreconditionerList( - block_list=state_lists[DISTRIBUTOR].global_blocked_params, + block_list=state_lists[DISTRIBUTOR].local_blocked_params, ) elif type(group[GRAFTING_CONFIG]) in ( AdaGradGraftingConfig, @@ -545,10 +544,9 @@ def _instantiate_grafting(self) -> None: AdamGraftingConfig, ): state_lists[GRAFTING_PRECONDITIONER_LIST] = AdagradPreconditionerList( - block_list=state_lists[DISTRIBUTOR].global_blocked_params, + block_list=state_lists[DISTRIBUTOR].local_blocked_params, state=self.state, - block_info_list=state_lists[DISTRIBUTOR].global_block_info_list, - distributor_selector=state_lists[DISTRIBUTOR].distributor_selector, + block_info_list=state_lists[DISTRIBUTOR].local_block_info_list, beta2=( 1.0 if type(group[GRAFTING_CONFIG]) is AdaGradGraftingConfig @@ -565,7 +563,7 @@ def _instantiate_grafting(self) -> None: def _instantiate_steps(self) -> None: for state_lists in self._per_group_state_lists: assert ( - len(state_lists[DISTRIBUTOR].global_block_info_list) > 0 + len(state_lists[DISTRIBUTOR].local_block_info_list) > 0 ), "There is no params in your param_group. Please check the instantiation of DistributedShampoo " 'with param_group containing no params. For example, DistributedShampoo(params=[{"params": []}])' # NOTE: We instantiate a single step tensor on CPU for each group in order @@ -575,7 +573,7 @@ def _instantiate_steps(self) -> None: # In order to ensure that the step counter is checkpointed correctly, we store it # as a tensor (which is replicated across all devices) under the first parameter's state. - block_info = state_lists[DISTRIBUTOR].global_block_info_list[0] + block_info = state_lists[DISTRIBUTOR].local_block_info_list[0] self.state[block_info.param][STEP] = state_lists[STEP] @torch.no_grad() @@ -586,11 +584,11 @@ def _instantiate_momentum(self) -> None: if group[MOMENTUM] == 0.0: continue - # Construct global momentum list. - global_momentum_list = [] + # Construct local momentum list. + local_momentum_list = [] for block, block_info in zip( - state_lists[DISTRIBUTOR].global_blocked_params, - state_lists[DISTRIBUTOR].global_block_info_list, + state_lists[DISTRIBUTOR].local_blocked_params, + state_lists[DISTRIBUTOR].local_block_info_list, strict=True, ): assert ( @@ -608,15 +606,9 @@ def _instantiate_momentum(self) -> None: dtype=block.dtype, device=block.device, ) - global_momentum_list.append( - block_info.get_tensor(block_state[MOMENTUM]) - ) + local_momentum_list.append(block_state[MOMENTUM]) - # We compress the momentum list to only the locally-owned parameter states. - state_lists[MOMENTUM_LIST] = compress_list( - global_momentum_list, - state_lists[DISTRIBUTOR].distributor_selector, - ) + state_lists[MOMENTUM_LIST] = local_momentum_list # Here, we set masked momentum list to momentum list because we assume # all parameters are active. state_lists[MASKED_MOMENTUM_LIST] = state_lists[MOMENTUM_LIST] @@ -629,11 +621,11 @@ def _instantiate_filtered_grads(self) -> None: if group[BETAS][0] == 0.0: continue - # Construct global filtered gradient list. - global_filtered_grad_list = [] + # Construct local filtered gradient list. + local_filtered_grad_list = [] for block, block_info in zip( - state_lists[DISTRIBUTOR].global_blocked_params, - state_lists[DISTRIBUTOR].global_block_info_list, + state_lists[DISTRIBUTOR].local_blocked_params, + state_lists[DISTRIBUTOR].local_block_info_list, strict=True, ): assert ( @@ -651,15 +643,11 @@ def _instantiate_filtered_grads(self) -> None: dtype=block.dtype, device=block.device, ) - global_filtered_grad_list.append( + local_filtered_grad_list.append( block_info.get_tensor(block_state[FILTERED_GRAD]) ) - # We compress the momentum list to only the locally-owned parameter states. - state_lists[FILTERED_GRAD_LIST] = compress_list( - global_filtered_grad_list, - state_lists[DISTRIBUTOR].distributor_selector, - ) + state_lists[FILTERED_GRAD_LIST] = local_filtered_grad_list # Here, we set masked filtered grad list to filtered grad list because we assume # all parameters are active. state_lists[MASKED_FILTERED_GRAD_LIST] = state_lists[FILTERED_GRAD_LIST] diff --git a/distributed_shampoo/utils/shampoo_ddp_distributor.py b/distributed_shampoo/utils/shampoo_ddp_distributor.py index b742133..eb76ff8 100644 --- a/distributed_shampoo/utils/shampoo_ddp_distributor.py +++ b/distributed_shampoo/utils/shampoo_ddp_distributor.py @@ -106,12 +106,13 @@ def __init__( ) ) - self._construct_global_block_info_list(buffer_size_ranks) - + global_block_info_list = self._construct_global_block_info_list( + buffer_size_ranks + ) # Initialize selectors and local blocked (masked) parameters. self._distributor_selector: tuple[bool, ...] = tuple( block_info.group_source_rank == group_rank - for block_info in self._global_block_info_list + for block_info in global_block_info_list ) self._local_blocked_params: tuple[Tensor, ...] = compress_list( self._global_blocked_params, self._distributor_selector @@ -122,6 +123,9 @@ def __init__( self._local_grad_selector: tuple[bool, ...] = (True,) * len( self._local_blocked_params ) + self._local_block_info_list: tuple[DDPBlockInfo, ...] = compress_list( + global_block_info_list, self._distributor_selector + ) self._construct_distributed_buffers( buffer_size_ranks=buffer_size_ranks, @@ -257,9 +261,10 @@ def _distribute_buffer_sizes( return tuple(buffer_size_ranks) + @torch.no_grad() def _construct_global_block_info_list( self, buffer_size_ranks: tuple[tuple[int, int], ...] - ) -> None: + ) -> tuple[DDPBlockInfo, ...]: """Construct the global block info list. Args: @@ -267,8 +272,7 @@ def _construct_global_block_info_list( and an assigned rank for each block. """ - # Construct global block info list. - self._global_block_info_list: tuple[DDPBlockInfo, ...] = tuple( + return tuple( DDPBlockInfo( param=param, composable_block_ids=self._construct_composable_block_ids( @@ -392,7 +396,7 @@ def _construct_distributed_buffers( self._global_dist_buffer = torch.zeros( total_buffer_size, dtype=torch.int8, - device=self._global_block_info_list[0].param.device, + device=self._global_blocked_params[0].device, ) local_dist_buffers = torch.split(self._global_dist_buffer, max_buffer_size_sum) splitted_local_dist_buffers = DDPDistributor._split_local_dist_buffers( diff --git a/distributed_shampoo/utils/shampoo_distributor.py b/distributed_shampoo/utils/shampoo_distributor.py index 3a007f9..0bbc123 100644 --- a/distributed_shampoo/utils/shampoo_distributor.py +++ b/distributed_shampoo/utils/shampoo_distributor.py @@ -63,8 +63,8 @@ def __init__(self, param_group: dict[str, Any]) -> None: self._local_blocked_params: tuple[Tensor, ...] # Local masked blocked params are the parameters masked by the distributor selector AND the local grad selector. self._local_masked_blocked_params: tuple[Tensor, ...] - # Global block info list contains information about each global block. - self._global_block_info_list: tuple[BlockInfo, ...] | tuple[DDPBlockInfo, ...] + # Local block info list contains information about each block masked by the distributor selector. + self._local_block_info_list: tuple[BlockInfo, ...] | tuple[DDPBlockInfo, ...] @abstractmethod @torch.no_grad() @@ -73,14 +73,6 @@ def update_params( masked_blocked_search_directions: tuple[Tensor, ...], ) -> None: ... - @property - def global_blocked_params(self) -> tuple[Tensor, ...]: - return self._global_blocked_params - - @property - def distributor_selector(self) -> tuple[bool, ...]: - return self._distributor_selector - @property def local_grad_selector(self) -> tuple[bool, ...]: return self._local_grad_selector @@ -94,8 +86,8 @@ def local_masked_blocked_params(self) -> tuple[Tensor, ...]: return self._local_masked_blocked_params @property - def global_block_info_list(self) -> tuple[BlockInfo, ...]: - return self._global_block_info_list + def local_block_info_list(self) -> tuple[BlockInfo, ...]: + return self._local_block_info_list def _construct_composable_block_ids( self, @@ -258,18 +250,18 @@ def __init__( param_group: dict[str, Any], ) -> None: super().__init__(param_group) - self._construct_global_block_info_list() # Initialize selectors and local blocked (masked) parameters. self._local_grad_selector: tuple[bool, ...] = (True,) * len( self._global_blocked_params ) self._distributor_selector: tuple[bool, ...] = self._local_grad_selector + self._local_blocked_params: tuple[Tensor, ...] = self._global_blocked_params self._local_masked_blocked_params: tuple[Tensor, ...] = ( - self._global_blocked_params + self._local_blocked_params ) - self._local_blocked_params: tuple[Tensor, ...] = ( - self._local_masked_blocked_params + self._local_block_info_list: tuple[BlockInfo, ...] = ( + self._construct_local_block_info_list() ) @torch.no_grad() @@ -288,11 +280,12 @@ def update_params( masked_blocked_search_directions, ) - def _construct_global_block_info_list( + @torch.no_grad() + def _construct_local_block_info_list( self, - ) -> None: - """Construct global block info list from param_group and num_blocks_within_param.""" - self._global_block_info_list = tuple( + ) -> tuple[BlockInfo, ...]: + """Construct local block info list from param_group and num_blocks_within_param.""" + return tuple( BlockInfo( param=param, composable_block_ids=self._construct_composable_block_ids( diff --git a/distributed_shampoo/utils/shampoo_fsdp_distributor.py b/distributed_shampoo/utils/shampoo_fsdp_distributor.py index 852f598..b80d9fe 100644 --- a/distributed_shampoo/utils/shampoo_fsdp_distributor.py +++ b/distributed_shampoo/utils/shampoo_fsdp_distributor.py @@ -55,7 +55,6 @@ def __init__( self._global_num_blocks_per_split_param: tuple[int, ...] = () super().__init__(param_group) - self._construct_global_block_info_list() # Initialize selectors and local blocked (masked) parameters. self._local_grad_selector: tuple[bool, ...] = (True,) * len( @@ -66,6 +65,9 @@ def __init__( self._global_blocked_params ) self._local_blocked_params: tuple[Tensor, ...] = self._global_blocked_params + self._local_block_info_list: tuple[BlockInfo, ...] = ( + self._construct_local_block_info_list() + ) @torch.no_grad() def update_params( @@ -102,12 +104,13 @@ def _construct_composable_block_ids( """ return (param_index, f"rank_{rank}-block_{block_index}") - def _construct_global_block_info_list( + @torch.no_grad() + def _construct_local_block_info_list( self, - ) -> None: - """Construct global block info list from param_group and num_blocks_within_param.""" + ) -> tuple[BlockInfo, ...]: + """Construct local block info list from param_group and num_blocks_within_param.""" rank = dist.get_rank() - self._global_block_info_list = tuple( + return tuple( BlockInfo( param=param, composable_block_ids=self._construct_composable_block_ids( diff --git a/distributed_shampoo/utils/shampoo_fully_shard_distributor.py b/distributed_shampoo/utils/shampoo_fully_shard_distributor.py index e63b359..47c7a2f 100644 --- a/distributed_shampoo/utils/shampoo_fully_shard_distributor.py +++ b/distributed_shampoo/utils/shampoo_fully_shard_distributor.py @@ -65,10 +65,10 @@ def _construct_composable_block_ids( return (param_index, f"rank_{rank}-block_{block_index}") @torch.no_grad() - def _construct_global_block_info_list( + def _construct_local_block_info_list( self, - ) -> None: - """Construct global block info list from param_group and num_blocks_within_param.""" + ) -> tuple[BlockInfo, ...]: + """Construct local block info list from param_group and num_blocks_within_param.""" rank = dist.get_rank() # Call `super()` instead of `self` as a performance optimization. @@ -77,7 +77,7 @@ def _construct_global_block_info_list( lambda p: p.to_local().numel() > 0, # type: ignore[arg-type] super()._get_params_or_grads(), ) - self._global_block_info_list = tuple( + return tuple( BlockInfo( param=param, composable_block_ids=self._construct_composable_block_ids( diff --git a/distributed_shampoo/utils/shampoo_hsdp_distributor.py b/distributed_shampoo/utils/shampoo_hsdp_distributor.py index ae5874d..e8042ce 100644 --- a/distributed_shampoo/utils/shampoo_hsdp_distributor.py +++ b/distributed_shampoo/utils/shampoo_hsdp_distributor.py @@ -200,12 +200,13 @@ def __init__( ) ) - self._construct_global_block_info_list(buffer_size_ranks) - + global_block_info_list = self._construct_global_block_info_list( + buffer_size_ranks + ) # Initialize selectors and local blocked (masked) parameters. self._distributor_selector: tuple[bool, ...] = tuple( block_info.group_source_rank == comms_group_rank - for block_info in self._global_block_info_list + for block_info in global_block_info_list ) self._local_blocked_params: tuple[Tensor, ...] = compress_list( self._global_blocked_params, self._distributor_selector @@ -216,6 +217,9 @@ def __init__( self._local_grad_selector: tuple[bool, ...] = (True,) * len( self._local_blocked_params ) + self._local_block_info_list: tuple[DDPBlockInfo, ...] = compress_list( + global_block_info_list, self._distributor_selector + ) self._construct_distributed_buffers( buffer_size_ranks=buffer_size_ranks, @@ -369,14 +373,15 @@ def _construct_composable_block_ids( """ return (param_index, f"rank_{rank}-block_{block_index}") + @torch.no_grad() def _construct_global_block_info_list( self, buffer_size_ranks: tuple[tuple[int, int], ...] - ) -> None: + ) -> tuple[DDPBlockInfo, ...]: """Construct global block info list from param_group and num_blocks_within_param.""" # Note that for HSDP, we want to get the rank within each sharded group for the block id. # When using a device mesh, 0 corresponds to the replicated group and 1 corresponds to the sharded group. sharded_group_rank = self._hsdp_device_mesh.get_local_rank(1) - self._global_block_info_list: tuple[DDPBlockInfo, ...] = tuple( + return tuple( DDPBlockInfo( param=param, composable_block_ids=self._construct_composable_block_ids( @@ -575,7 +580,7 @@ def _construct_distributed_buffers( self._global_dist_buffer = torch.zeros( total_buffer_size, dtype=torch.int8, - device=self._global_block_info_list[0].param.device, + device=self._global_blocked_params[0].device, ) local_dist_buffers = torch.split(self._global_dist_buffer, max_buffer_size_sum) splitted_local_dist_buffers = HSDPDistributor._split_local_dist_buffers( diff --git a/distributed_shampoo/utils/shampoo_hybrid_shard_distributor.py b/distributed_shampoo/utils/shampoo_hybrid_shard_distributor.py index 3f04d09..795c36d 100644 --- a/distributed_shampoo/utils/shampoo_hybrid_shard_distributor.py +++ b/distributed_shampoo/utils/shampoo_hybrid_shard_distributor.py @@ -186,12 +186,14 @@ def __init__( ) ) - self._construct_global_block_info_list(buffer_size_ranks) + global_block_info_list = self._construct_global_block_info_list( + buffer_size_ranks + ) # Initialize selectors and local blocked (masked) parameters. self._distributor_selector: tuple[bool, ...] = tuple( block_info.group_source_rank == comms_group_rank - for block_info in self._global_block_info_list + for block_info in global_block_info_list ) self._local_blocked_params: tuple[Tensor, ...] = compress_list( self._global_blocked_params, self._distributor_selector @@ -202,6 +204,9 @@ def __init__( self._local_grad_selector: tuple[bool, ...] = (True,) * len( self._local_blocked_params ) + self._local_block_info_list: tuple[DDPBlockInfo, ...] = compress_list( + global_block_info_list, self._distributor_selector + ) self._construct_distributed_buffers( buffer_size_ranks=buffer_size_ranks, @@ -375,7 +380,7 @@ def _construct_composable_block_ids( @torch.no_grad() def _construct_global_block_info_list( self, buffer_size_ranks: tuple[tuple[int, int], ...] - ) -> None: + ) -> tuple[DDPBlockInfo, ...]: """Construct global block info list from param_group and num_blocks_within_param.""" # Call `super()` instead of `self` as a performance optimization. # This leads to O(1) instead of O(N) complexity to retrieve the parameters. @@ -387,7 +392,7 @@ def _construct_global_block_info_list( # Note that for HybridShard, we want to get the rank within each sharded group for the block id. # When using a device mesh, 0 corresponds to the replicated group and 1 corresponds to the sharded group. sharded_group_rank = self._hybrid_shard_device_mesh.get_local_rank(1) - self._global_block_info_list: tuple[DDPBlockInfo, ...] = tuple( + return tuple( DDPBlockInfo( param=param, composable_block_ids=self._construct_composable_block_ids( @@ -512,7 +517,7 @@ def _construct_distributed_buffers( self._global_dist_buffer = torch.zeros( total_buffer_size, dtype=torch.int8, - device=self._global_block_info_list[0].param.device, + device=self._global_blocked_params[0].device, ) local_dist_buffers = torch.split(self._global_dist_buffer, max_buffer_size_sum) splitted_local_dist_buffers = HybridShardDistributor._split_local_dist_buffers( diff --git a/distributed_shampoo/utils/shampoo_preconditioner_list.py b/distributed_shampoo/utils/shampoo_preconditioner_list.py index 85b2f99..9e20ad0 100644 --- a/distributed_shampoo/utils/shampoo_preconditioner_list.py +++ b/distributed_shampoo/utils/shampoo_preconditioner_list.py @@ -143,8 +143,6 @@ class AdagradPreconditionerList(PreconditionerList): state (Mapping[Tensor, Any]): Mapping containing optimizer state. block_info_list (tuple[BlockInfo, ...]): List containing corresponding BlockInfo for each block/parameter in block_list. Note that this should have the same length as block_list. - distributor_selector (tuple[bool, ...]): Distributor selector is a boolean list indicating whether a blocked parameter - is selected by the current Distributor. beta2 (float): Exponential moving average factor for Adam/RMSprop second moment state. If beta2 = 1., will use unweighted sum. (Default: 1.0) epsilon (float): Epsilon term for regularizing preconditioner to ensure positive definiteness. (Default: 1e-10) @@ -158,7 +156,6 @@ def __init__( # type: ignore state: Mapping[Tensor, Any], block_info_list: tuple[BlockInfo, ...], - distributor_selector: tuple[bool, ...], beta2: float = 1.0, epsilon: float = 1e-10, use_bias_correction: bool = True, @@ -176,7 +173,7 @@ def __init__( # and do not explicitly store them as AdagradPreconditionerList attributes here. # This is because the optimizer state is defined per-parameter, but AdagradPreconditionerList is defined # across each parameter group (which includes multiple parameters). - preconditioner_list = [] + preconditioner_list: list[Tensor] = [] for block, block_info in zip(block_list, block_info_list, strict=True): param_index, block_index = block_info.composable_block_ids if block_index not in state[block_info.param]: @@ -198,17 +195,12 @@ def __init__( ) # Masked lists are the list of active preconditioners or values after filtering out gradients with None. - self._local_preconditioner_list: tuple[Tensor, ...] = compress_list( - preconditioner_list, distributor_selector - ) + self._local_preconditioner_list: tuple[Tensor, ...] = tuple(preconditioner_list) self._masked_preconditioner_list: tuple[Tensor, ...] = ( self._local_preconditioner_list ) - # Construct lists of dims, bytes, and numels for logging purposes. - self._dims_list: tuple[torch.Size, ...] = compress_list( - self._dims_list, distributor_selector - ) + # Construct lists of numels and bytes for logging purposes. self._numel_list: tuple[int, ...] = tuple( preconditioner.numel() for preconditioner in self._local_preconditioner_list ) @@ -363,8 +355,6 @@ class BaseShampooPreconditionerList( state (Mapping[Tensor, Any]): Mapping containing optimizer state. block_info_list (tuple[BlockInfo, ...]): List containing corresponding BlockInfo for each block/parameter in block_list. Note that this should have the same length as block_list. - distributor_selector (tuple[bool, ...]): Distributor selector is a boolean list indicating whether a blocked parameter - is selected by the current Distributor. preconditioner_config (PreconditionerConfig): Configuration for preconditioner computation. (Default: DefaultShampooConfig) beta2 (float): Exponential moving average factor for Shampoo factor matrices. If beta2 = 1., will use unweighted sum. (Default: 1.0) @@ -384,7 +374,6 @@ def __init__( # type: ignore state: Mapping[Tensor, Any], block_info_list: tuple[BlockInfo, ...], - distributor_selector: tuple[bool, ...], preconditioner_config: PreconditionerConfig, beta2: float = 1.0, epsilon: float = 1e-12, @@ -416,7 +405,6 @@ def __init__( self._initialize_state_lists( block_list=block_list, kronecker_factors_list=kronecker_factors_list, - distributor_selector=distributor_selector, ) def _create_base_kronecker_factors( @@ -702,16 +690,14 @@ def _initialize_state_lists( self, block_list: tuple[Tensor, ...], kronecker_factors_list: list[ShampooKroneckerFactorsListType], - distributor_selector: tuple[bool, ...], ) -> None: # Initialize local lists. - local_block_list = compress_list(block_list, distributor_selector) self._local_kronecker_factors_list: tuple[ ShampooKroneckerFactorsListType, ..., - ] = compress_list(kronecker_factors_list, distributor_selector) + ] = tuple(kronecker_factors_list) self._local_order_list: tuple[int, ...] = tuple( - block.dim() for block in local_block_list + block.dim() for block in block_list ) self._local_root_list: tuple[int, ...] = self._get_inverse_roots_from_override( self._inv_root_override, @@ -734,9 +720,6 @@ def _initialize_state_lists( # Construct lists of bytes and numels for logging purposes. # NOTE: These lists are constructed across all blocked parameters. - self._dims_list: tuple[torch.Size, ...] = compress_list( - self._dims_list, distributor_selector - ) self._numel_list: tuple[int, ...] = tuple( sum(2 * dim**2 for dim in dims) for dims in self._dims_list ) @@ -744,7 +727,7 @@ def _initialize_state_lists( numel * (get_dtype_size(self._factor_matrix_dtype) + get_dtype_size(block.dtype)) // 2 - for numel, block in zip(self._numel_list, local_block_list, strict=True) + for numel, block in zip(self._numel_list, block_list, strict=True) ) def compress_preconditioner_list( diff --git a/distributed_shampoo/utils/tests/shampoo_distributor_test.py b/distributed_shampoo/utils/tests/shampoo_distributor_test.py index 5825930..ff23203 100644 --- a/distributed_shampoo/utils/tests/shampoo_distributor_test.py +++ b/distributed_shampoo/utils/tests/shampoo_distributor_test.py @@ -90,14 +90,6 @@ def test_update_params(self) -> None: actual_masked_blocked_params, expected_masked_blocked_params ) - def test_distributor_selector(self) -> None: - # Two blocks from the linear layer, and one block from the bias layer. - expected_distributor_selector = (True, True, True) - self.assertEqual( - self._distributor.distributor_selector, - expected_distributor_selector, - ) - def test_local_grad_selector(self) -> None: # Explicitly disable the gradient of the bias layer and call merge_and_block_gradients() # to update the local gradient selector for the bias layer (i.e., 3rd block). @@ -111,17 +103,6 @@ def test_local_grad_selector(self) -> None: expected_local_grad_selector, ) - def test_global_blocked_params(self) -> None: - expected_global_params = ( - torch.zeros(5, 5, dtype=torch.float), - torch.zeros(5, 5, dtype=torch.float), - torch.zeros(5, dtype=torch.float), - ) - torch.testing.assert_close( - self._distributor.global_blocked_params, - expected_global_params, - ) - def test_local_blocked_params(self) -> None: # In Distributor, because there is no global vs. local boundary concept, # global and local blocked params are always identical. @@ -135,8 +116,8 @@ def test_local_blocked_params(self) -> None: expected_local_params, ) - def test_global_block_info_list(self) -> None: - expected_global_block_info_list = ( + def test_local_block_info_list(self) -> None: + expected_local_block_info_list = ( BlockInfo( param=self._model.linear_layers[0].weight, composable_block_ids=(0, "block_0"), @@ -151,8 +132,8 @@ def test_global_block_info_list(self) -> None: ), ) self.assertEqual( - self._distributor.global_block_info_list, - expected_global_block_info_list, + self._distributor.local_block_info_list, + expected_local_block_info_list, ) def test_merge_and_block_gradients(self) -> None: diff --git a/distributed_shampoo/utils/tests/shampoo_preconditioner_list_test.py b/distributed_shampoo/utils/tests/shampoo_preconditioner_list_test.py index 8c4a926..d1fba5b 100644 --- a/distributed_shampoo/utils/tests/shampoo_preconditioner_list_test.py +++ b/distributed_shampoo/utils/tests/shampoo_preconditioner_list_test.py @@ -167,11 +167,10 @@ def test_compress_preconditioner_list(self) -> None: class AdagradPreconditionerListTest(PreconditionerListTest): def _instantiate_block_list(self) -> tuple[Tensor, ...]: # Because maximum_preconditioner_dim = 2, self._params[0] forms a block by itself, - # self._params[1] are split into two blocks, and self._params[2] forms a block by itself. + # and self._params[1] are split into two blocks. return ( self._params[0], *torch.split(self._params[1], 2, dim=0), - self._params[2], ) def _instantiate_preconditioner_list( @@ -182,7 +181,6 @@ def _instantiate_preconditioner_list( block_list=self._block_list, state=self._state, block_info_list=self._block_info_list, - distributor_selector=self._distributor_selector, **kwargs, ) @@ -190,16 +188,13 @@ def setUp(self) -> None: self._params = ( torch.tensor([1.0, 2.0]), torch.arange(6, dtype=torch.float).reshape(3, 2), - # Following param will not be used due to the distributor selector below. - torch.tensor([torch.nan, torch.nan]), ) self._state = { # type: ignore[var-annotated] self._params[0]: {}, self._params[1]: {}, - self._params[2]: {}, } # Because maximum_preconditioner_dim = 2, self._params[0] forms a block by itself, - # self._params[1] are split into two blocks, and self._params[2] forms a block by itself. + # and self._params[1] are split into two blocks. self._block_info_list = ( BlockInfo( param=self._params[0], @@ -213,13 +208,7 @@ def setUp(self) -> None: param=self._params[1], composable_block_ids=(1, "block_1"), ), - BlockInfo( - param=self._params[2], - composable_block_ids=(2, "block_0"), - ), ) - # Ignores the last block, which is self._params[2] itself. - self._distributor_selector = (True, True, True, False) super().setUp() def test_update_preconditioners_and_precondition(self) -> None: @@ -295,7 +284,6 @@ def test_abstract_methods(self) -> None: composable_block_ids=(0, "block_0"), ), ), - distributor_selector=(True,), preconditioner_config=DefaultShampooConfig, beta2=1.0, ) @@ -627,7 +615,6 @@ def _instantiate_preconditioner_list( # type: ignore[override] block_list=self._block_list, state=self._state, block_info_list=self._block_info_list, - distributor_selector=self._distributor_selector, factor_matrix_dtype=torch.float64, **kwargs, # type: ignore[arg-type] ) @@ -871,7 +858,6 @@ def _instantiate_preconditioner_list( # type: ignore[override] block_list=self._block_list, state=self._state, block_info_list=self._block_info_list, - distributor_selector=self._distributor_selector, factor_matrix_dtype=torch.float64, **kwargs, # type: ignore[arg-type] )