diff --git a/distributed_shampoo/distributed_shampoo.py b/distributed_shampoo/distributed_shampoo.py
index a922d61..ac77084 100644
--- a/distributed_shampoo/distributed_shampoo.py
+++ b/distributed_shampoo/distributed_shampoo.py
@@ -516,10 +516,9 @@ def _instantiate_shampoo_preconditioner_list(self) -> None:
             )

             state_lists[SHAMPOO_PRECONDITIONER_LIST] = preconditioner_list_cls(
-                block_list=state_lists[DISTRIBUTOR].global_blocked_params,
+                block_list=state_lists[DISTRIBUTOR].local_blocked_params,
                 state=self.state,
-                block_info_list=state_lists[DISTRIBUTOR].global_block_info_list,
-                distributor_selector=state_lists[DISTRIBUTOR].distributor_selector,
+                block_info_list=state_lists[DISTRIBUTOR].local_block_info_list,
                 preconditioner_config=group[PRECONDITIONER_CONFIG],
                 beta2=group[BETAS][1],
                 epsilon=group[EPSILON],
@@ -537,7 +536,7 @@ def _instantiate_grafting(self) -> None:
                 state_lists[GRAFTING_PRECONDITIONER_LIST] = None
             elif type(group[GRAFTING_CONFIG]) is SGDGraftingConfig:
                 state_lists[GRAFTING_PRECONDITIONER_LIST] = SGDPreconditionerList(
-                    block_list=state_lists[DISTRIBUTOR].global_blocked_params,
+                    block_list=state_lists[DISTRIBUTOR].local_blocked_params,
                 )
             elif type(group[GRAFTING_CONFIG]) in (
                 AdaGradGraftingConfig,
@@ -545,10 +544,9 @@ def _instantiate_grafting(self) -> None:
                 AdamGraftingConfig,
             ):
                 state_lists[GRAFTING_PRECONDITIONER_LIST] = AdagradPreconditionerList(
-                    block_list=state_lists[DISTRIBUTOR].global_blocked_params,
+                    block_list=state_lists[DISTRIBUTOR].local_blocked_params,
                     state=self.state,
-                    block_info_list=state_lists[DISTRIBUTOR].global_block_info_list,
-                    distributor_selector=state_lists[DISTRIBUTOR].distributor_selector,
+                    block_info_list=state_lists[DISTRIBUTOR].local_block_info_list,
                     beta2=(
                         1.0
                         if type(group[GRAFTING_CONFIG]) is AdaGradGraftingConfig
@@ -565,7 +563,7 @@ def _instantiate_steps(self) -> None:
         for state_lists in self._per_group_state_lists:
             assert (
-                len(state_lists[DISTRIBUTOR].global_block_info_list) > 0
+                len(state_lists[DISTRIBUTOR].local_block_info_list) > 0
             ), "There is no params in your param_group. Please check the instantiation of DistributedShampoo "
             'with param_group containing no params. For example, DistributedShampoo(params=[{"params": []}])'

             # NOTE: We instantiate a single step tensor on CPU for each group in order
@@ -575,7 +573,7 @@ def _instantiate_steps(self) -> None:

             # In order to ensure that the step counter is checkpointed correctly, we store it
             # as a tensor (which is replicated across all devices) under the first parameter's state.
-            block_info = state_lists[DISTRIBUTOR].global_block_info_list[0]
+            block_info = state_lists[DISTRIBUTOR].local_block_info_list[0]
             self.state[block_info.param][STEP] = state_lists[STEP]

     @torch.no_grad()
@@ -586,11 +584,11 @@ def _instantiate_momentum(self) -> None:
             if group[MOMENTUM] == 0.0:
                 continue

-            # Construct global momentum list.
-            global_momentum_list = []
+            # Construct local momentum list.
+            local_momentum_list = []
             for block, block_info in zip(
-                state_lists[DISTRIBUTOR].global_blocked_params,
-                state_lists[DISTRIBUTOR].global_block_info_list,
+                state_lists[DISTRIBUTOR].local_blocked_params,
+                state_lists[DISTRIBUTOR].local_block_info_list,
                 strict=True,
             ):
                 assert (
@@ -608,15 +606,9 @@ def _instantiate_momentum(self) -> None:
                     dtype=block.dtype,
                     device=block.device,
                 )
-                global_momentum_list.append(
-                    block_info.get_tensor(block_state[MOMENTUM])
-                )
+                local_momentum_list.append(block_state[MOMENTUM])

-            # We compress the momentum list to only the locally-owned parameter states.
-            state_lists[MOMENTUM_LIST] = compress_list(
-                global_momentum_list,
-                state_lists[DISTRIBUTOR].distributor_selector,
-            )
+            state_lists[MOMENTUM_LIST] = local_momentum_list
             # Here, we set masked momentum list to momentum list because we assume
             # all parameters are active.
             state_lists[MASKED_MOMENTUM_LIST] = state_lists[MOMENTUM_LIST]
@@ -629,11 +621,11 @@ def _instantiate_filtered_grads(self) -> None:
             if group[BETAS][0] == 0.0:
                 continue

-            # Construct global filtered gradient list.
-            global_filtered_grad_list = []
+            # Construct local filtered gradient list.
+            local_filtered_grad_list = []
             for block, block_info in zip(
-                state_lists[DISTRIBUTOR].global_blocked_params,
-                state_lists[DISTRIBUTOR].global_block_info_list,
+                state_lists[DISTRIBUTOR].local_blocked_params,
+                state_lists[DISTRIBUTOR].local_block_info_list,
                 strict=True,
             ):
                 assert (
@@ -651,15 +643,11 @@ def _instantiate_filtered_grads(self) -> None:
                     dtype=block.dtype,
                     device=block.device,
                 )
-                global_filtered_grad_list.append(
+                local_filtered_grad_list.append(
                     block_info.get_tensor(block_state[FILTERED_GRAD])
                 )

-            # We compress the momentum list to only the locally-owned parameter states.
-            state_lists[FILTERED_GRAD_LIST] = compress_list(
-                global_filtered_grad_list,
-                state_lists[DISTRIBUTOR].distributor_selector,
-            )
+            state_lists[FILTERED_GRAD_LIST] = local_filtered_grad_list
             # Here, we set masked filtered grad list to filtered grad list because we assume
             # all parameters are active.
             state_lists[MASKED_FILTERED_GRAD_LIST] = state_lists[FILTERED_GRAD_LIST]
diff --git a/distributed_shampoo/utils/shampoo_ddp_distributor.py b/distributed_shampoo/utils/shampoo_ddp_distributor.py
index e2d949f..ea8aa15 100644
--- a/distributed_shampoo/utils/shampoo_ddp_distributor.py
+++ b/distributed_shampoo/utils/shampoo_ddp_distributor.py
@@ -106,12 +106,13 @@ def __init__(
             )
         )

-        self._construct_global_block_info_list(buffer_size_ranks)
-
+        global_block_info_list = self._construct_global_block_info_list(
+            buffer_size_ranks
+        )
         # Initialize selectors and local blocked (masked) parameters.
         self._distributor_selector: tuple[bool, ...] = tuple(
             block_info.group_source_rank == group_rank
-            for block_info in self._global_block_info_list
+            for block_info in global_block_info_list
         )
         self._local_blocked_params: tuple[Tensor, ...] = compress_list(
             self._global_blocked_params, self._distributor_selector
@@ -122,6 +123,9 @@ def __init__(
         self._local_grad_selector: tuple[bool, ...] = (True,) * len(
             self._local_blocked_params
         )
+        self._local_block_info_list: tuple[DDPBlockInfo, ...] = compress_list(
+            global_block_info_list, self._distributor_selector
+        )

         self._construct_distributed_buffers(
             buffer_size_ranks=buffer_size_ranks,
@@ -257,9 +261,10 @@ def _distribute_buffer_sizes(

         return tuple(buffer_size_ranks)

+    @torch.no_grad()
     def _construct_global_block_info_list(
         self, buffer_size_ranks: tuple[tuple[int, int], ...]
-    ) -> None:
+    ) -> tuple[DDPBlockInfo, ...]:
         """Construct the global block info list.

         Args:
@@ -267,8 +272,7 @@ def _construct_global_block_info_list(
                 and an assigned rank for each block.

         """
-        # Construct global block info list.
-        self._global_block_info_list: tuple[DDPBlockInfo, ...] = tuple(
+        return tuple(
             DDPBlockInfo(
                 param=param,
                 composable_block_ids=self._construct_composable_block_ids(
@@ -392,7 +396,7 @@ def _construct_distributed_buffers(
         self._global_dist_buffer = torch.zeros(
             total_buffer_size,
             dtype=torch.int8,
-            device=self._global_block_info_list[0].param.device,
+            device=self._global_blocked_params[0].device,
         )
         local_dist_buffers = torch.split(self._global_dist_buffer, max_buffer_size_sum)
         splitted_local_dist_buffers = DDPDistributor._split_local_dist_buffers(
diff --git a/distributed_shampoo/utils/shampoo_distributor.py b/distributed_shampoo/utils/shampoo_distributor.py
index 3a007f9..0bbc123 100644
--- a/distributed_shampoo/utils/shampoo_distributor.py
+++ b/distributed_shampoo/utils/shampoo_distributor.py
@@ -63,8 +63,8 @@ def __init__(self, param_group: dict[str, Any]) -> None:
         self._local_blocked_params: tuple[Tensor, ...]
         # Local masked blocked params are the parameters masked by the distributor selector AND the local grad selector.
         self._local_masked_blocked_params: tuple[Tensor, ...]
-        # Global block info list contains information about each global block.
-        self._global_block_info_list: tuple[BlockInfo, ...] | tuple[DDPBlockInfo, ...]
+        # Local block info list contains information about each block masked by the distributor selector.
+        self._local_block_info_list: tuple[BlockInfo, ...] | tuple[DDPBlockInfo, ...]

     @abstractmethod
     @torch.no_grad()
@@ -73,14 +73,6 @@ def update_params(
         masked_blocked_search_directions: tuple[Tensor, ...],
     ) -> None: ...

-    @property
-    def global_blocked_params(self) -> tuple[Tensor, ...]:
-        return self._global_blocked_params
-
-    @property
-    def distributor_selector(self) -> tuple[bool, ...]:
-        return self._distributor_selector
-
     @property
     def local_grad_selector(self) -> tuple[bool, ...]:
         return self._local_grad_selector
@@ -94,8 +86,8 @@ def local_masked_blocked_params(self) -> tuple[Tensor, ...]:
         return self._local_masked_blocked_params

     @property
-    def global_block_info_list(self) -> tuple[BlockInfo, ...]:
-        return self._global_block_info_list
+    def local_block_info_list(self) -> tuple[BlockInfo, ...]:
+        return self._local_block_info_list

     def _construct_composable_block_ids(
         self,
@@ -258,18 +250,18 @@ def __init__(
         param_group: dict[str, Any],
     ) -> None:
         super().__init__(param_group)
-        self._construct_global_block_info_list()

         # Initialize selectors and local blocked (masked) parameters.
         self._local_grad_selector: tuple[bool, ...] = (True,) * len(
             self._global_blocked_params
         )
         self._distributor_selector: tuple[bool, ...] = self._local_grad_selector
+        self._local_blocked_params: tuple[Tensor, ...] = self._global_blocked_params
         self._local_masked_blocked_params: tuple[Tensor, ...] = (
-            self._global_blocked_params
+            self._local_blocked_params
         )
-        self._local_blocked_params: tuple[Tensor, ...] = (
-            self._local_masked_blocked_params
+        self._local_block_info_list: tuple[BlockInfo, ...] = (
+            self._construct_local_block_info_list()
         )

     @torch.no_grad()
@@ -288,11 +280,12 @@ def update_params(
             masked_blocked_search_directions,
         )

-    def _construct_global_block_info_list(
+    @torch.no_grad()
+    def _construct_local_block_info_list(
         self,
-    ) -> None:
-        """Construct global block info list from param_group and num_blocks_within_param."""
-        self._global_block_info_list = tuple(
+    ) -> tuple[BlockInfo, ...]:
+        """Construct local block info list from param_group and num_blocks_within_param."""
+        return tuple(
             BlockInfo(
                 param=param,
                 composable_block_ids=self._construct_composable_block_ids(
diff --git a/distributed_shampoo/utils/shampoo_fsdp_distributor.py b/distributed_shampoo/utils/shampoo_fsdp_distributor.py
index 852f598..b80d9fe 100644
--- a/distributed_shampoo/utils/shampoo_fsdp_distributor.py
+++ b/distributed_shampoo/utils/shampoo_fsdp_distributor.py
@@ -55,7 +55,6 @@ def __init__(
         self._global_num_blocks_per_split_param: tuple[int, ...] = ()

         super().__init__(param_group)
-        self._construct_global_block_info_list()

         # Initialize selectors and local blocked (masked) parameters.
         self._local_grad_selector: tuple[bool, ...] = (True,) * len(
@@ -66,6 +65,9 @@ def __init__(
             self._global_blocked_params
         )
         self._local_blocked_params: tuple[Tensor, ...] = self._global_blocked_params
+        self._local_block_info_list: tuple[BlockInfo, ...] = (
+            self._construct_local_block_info_list()
+        )

     @torch.no_grad()
     def update_params(
@@ -102,12 +104,13 @@ def _construct_composable_block_ids(
         """
         return (param_index, f"rank_{rank}-block_{block_index}")

-    def _construct_global_block_info_list(
+    @torch.no_grad()
+    def _construct_local_block_info_list(
         self,
-    ) -> None:
-        """Construct global block info list from param_group and num_blocks_within_param."""
+    ) -> tuple[BlockInfo, ...]:
+        """Construct local block info list from param_group and num_blocks_within_param."""
         rank = dist.get_rank()
-        self._global_block_info_list = tuple(
+        return tuple(
             BlockInfo(
                 param=param,
                 composable_block_ids=self._construct_composable_block_ids(
diff --git a/distributed_shampoo/utils/shampoo_fully_shard_distributor.py b/distributed_shampoo/utils/shampoo_fully_shard_distributor.py
index e63b359..47c7a2f 100644
--- a/distributed_shampoo/utils/shampoo_fully_shard_distributor.py
+++ b/distributed_shampoo/utils/shampoo_fully_shard_distributor.py
@@ -65,10 +65,10 @@ def _construct_composable_block_ids(
         return (param_index, f"rank_{rank}-block_{block_index}")

     @torch.no_grad()
-    def _construct_global_block_info_list(
+    def _construct_local_block_info_list(
         self,
-    ) -> None:
-        """Construct global block info list from param_group and num_blocks_within_param."""
+    ) -> tuple[BlockInfo, ...]:
+        """Construct local block info list from param_group and num_blocks_within_param."""
         rank = dist.get_rank()

         # Call `super()` instead of `self` as a performance optimization.
@@ -77,7 +77,7 @@ def _construct_global_block_info_list(
             lambda p: p.to_local().numel() > 0,  # type: ignore[arg-type]
             super()._get_params_or_grads(),
         )
-        self._global_block_info_list = tuple(
+        return tuple(
             BlockInfo(
                 param=param,
                 composable_block_ids=self._construct_composable_block_ids(
diff --git a/distributed_shampoo/utils/shampoo_hsdp_distributor.py b/distributed_shampoo/utils/shampoo_hsdp_distributor.py
index 68bc8b8..217ae9d 100644
--- a/distributed_shampoo/utils/shampoo_hsdp_distributor.py
+++ b/distributed_shampoo/utils/shampoo_hsdp_distributor.py
@@ -200,12 +200,13 @@ def __init__(
             )
         )

-        self._construct_global_block_info_list(buffer_size_ranks)
-
+        global_block_info_list = self._construct_global_block_info_list(
+            buffer_size_ranks
+        )
         # Initialize selectors and local blocked (masked) parameters.
         self._distributor_selector: tuple[bool, ...] = tuple(
             block_info.group_source_rank == comms_group_rank
-            for block_info in self._global_block_info_list
+            for block_info in global_block_info_list
         )
         self._local_blocked_params: tuple[Tensor, ...] = compress_list(
             self._global_blocked_params, self._distributor_selector
@@ -216,6 +217,9 @@ def __init__(
         self._local_grad_selector: tuple[bool, ...] = (True,) * len(
             self._local_blocked_params
         )
+        self._local_block_info_list: tuple[DDPBlockInfo, ...] = compress_list(
+            global_block_info_list, self._distributor_selector
+        )

         self._construct_distributed_buffers(
             buffer_size_ranks=buffer_size_ranks,
@@ -369,14 +373,15 @@ def _construct_composable_block_ids(
         """
         return (param_index, f"rank_{rank}-block_{block_index}")

+    @torch.no_grad()
     def _construct_global_block_info_list(
         self, buffer_size_ranks: tuple[tuple[int, int], ...]
-    ) -> None:
+    ) -> tuple[DDPBlockInfo, ...]:
         """Construct global block info list from param_group and num_blocks_within_param."""
         # Note that for HSDP, we want to get the rank within each sharded group for the block id.
         # When using a device mesh, 0 corresponds to the replicated group and 1 corresponds to the sharded group.
         sharded_group_rank = self._hsdp_device_mesh.get_local_rank(1)
-        self._global_block_info_list: tuple[DDPBlockInfo, ...] = tuple(
+        return tuple(
             DDPBlockInfo(
                 param=param,
                 composable_block_ids=self._construct_composable_block_ids(
@@ -575,7 +580,7 @@ def _construct_distributed_buffers(
         self._global_dist_buffer = torch.zeros(
             total_buffer_size,
             dtype=torch.int8,
-            device=self._global_block_info_list[0].param.device,
+            device=self._global_blocked_params[0].device,
         )
         local_dist_buffers = torch.split(self._global_dist_buffer, max_buffer_size_sum)
         splitted_local_dist_buffers = HSDPDistributor._split_local_dist_buffers(
diff --git a/distributed_shampoo/utils/shampoo_hybrid_shard_distributor.py b/distributed_shampoo/utils/shampoo_hybrid_shard_distributor.py
index a25a683..7d7e0d4 100644
--- a/distributed_shampoo/utils/shampoo_hybrid_shard_distributor.py
+++ b/distributed_shampoo/utils/shampoo_hybrid_shard_distributor.py
@@ -186,12 +186,14 @@ def __init__(
             )
         )

-        self._construct_global_block_info_list(buffer_size_ranks)
+        global_block_info_list = self._construct_global_block_info_list(
+            buffer_size_ranks
+        )

         # Initialize selectors and local blocked (masked) parameters.
         self._distributor_selector: tuple[bool, ...] = tuple(
             block_info.group_source_rank == comms_group_rank
-            for block_info in self._global_block_info_list
+            for block_info in global_block_info_list
         )
         self._local_blocked_params: tuple[Tensor, ...] = compress_list(
             self._global_blocked_params, self._distributor_selector
@@ -202,6 +204,9 @@ def __init__(
         self._local_grad_selector: tuple[bool, ...] = (True,) * len(
             self._local_blocked_params
         )
+        self._local_block_info_list: tuple[DDPBlockInfo, ...] = compress_list(
+            global_block_info_list, self._distributor_selector
+        )

         self._construct_distributed_buffers(
             buffer_size_ranks=buffer_size_ranks,
@@ -375,7 +380,7 @@ def _construct_composable_block_ids(
     @torch.no_grad()
     def _construct_global_block_info_list(
         self, buffer_size_ranks: tuple[tuple[int, int], ...]
-    ) -> None:
+    ) -> tuple[DDPBlockInfo, ...]:
         """Construct global block info list from param_group and num_blocks_within_param."""
         # Call `super()` instead of `self` as a performance optimization.
         # This leads to O(1) instead of O(N) complexity to retrieve the parameters.
@@ -387,7 +392,7 @@ def _construct_global_block_info_list(
         # Note that for HybridShard, we want to get the rank within each sharded group for the block id.
         # When using a device mesh, 0 corresponds to the replicated group and 1 corresponds to the sharded group.
         sharded_group_rank = self._hybrid_shard_device_mesh.get_local_rank(1)
-        self._global_block_info_list: tuple[DDPBlockInfo, ...] = tuple(
+        return tuple(
             DDPBlockInfo(
                 param=param,
                 composable_block_ids=self._construct_composable_block_ids(
@@ -512,7 +517,7 @@ def _construct_distributed_buffers(
         self._global_dist_buffer = torch.zeros(
             total_buffer_size,
             dtype=torch.int8,
-            device=self._global_block_info_list[0].param.device,
+            device=self._global_blocked_params[0].device,
         )
         local_dist_buffers = torch.split(self._global_dist_buffer, max_buffer_size_sum)
         splitted_local_dist_buffers = HybridShardDistributor._split_local_dist_buffers(
diff --git a/distributed_shampoo/utils/shampoo_preconditioner_list.py b/distributed_shampoo/utils/shampoo_preconditioner_list.py
index 8f0fd99..43cc067 100644
--- a/distributed_shampoo/utils/shampoo_preconditioner_list.py
+++ b/distributed_shampoo/utils/shampoo_preconditioner_list.py
@@ -143,8 +143,6 @@ class AdagradPreconditionerList(PreconditionerList):
         state (Mapping[Tensor, Any]): Mapping containing optimizer state.
         block_info_list (tuple[BlockInfo, ...]): List containing corresponding BlockInfo for each block/parameter in block_list.
             Note that this should have the same length as block_list.
-        distributor_selector (tuple[bool, ...]): Distributor selector is a boolean list indicating whether a blocked parameter
-            is selected by the current Distributor.
         beta2 (float): Exponential moving average factor for Adam/RMSprop second moment state. If beta2 = 1., will use unweighted sum. (Default: 1.0)
         epsilon (float): Epsilon term for regularizing preconditioner to ensure positive definiteness. (Default: 1e-10)
@@ -158,7 +156,6 @@ def __init__(  # type: ignore
         state: Mapping[Tensor, Any],
         block_info_list: tuple[BlockInfo, ...],
-        distributor_selector: tuple[bool, ...],
         beta2: float = 1.0,
         epsilon: float = 1e-10,
         use_bias_correction: bool = True,
@@ -176,7 +173,7 @@ def __init__(
         # and do not explicitly store them as AdagradPreconditionerList attributes here.
         # This is because the optimizer state is defined per-parameter, but AdagradPreconditionerList is defined
         # across each parameter group (which includes multiple parameters).
-        preconditioner_list = []
+        preconditioner_list: list[Tensor] = []
         for block, block_info in zip(block_list, block_info_list, strict=True):
             param_index, block_index = block_info.composable_block_ids
             if block_index not in state[block_info.param]:
@@ -198,17 +195,12 @@ def __init__(
             )

         # Masked lists are the list of active preconditioners or values after filtering out gradients with None.
-        self._local_preconditioner_list: tuple[Tensor, ...] = compress_list(
-            preconditioner_list, distributor_selector
-        )
+        self._local_preconditioner_list: tuple[Tensor, ...] = tuple(preconditioner_list)
         self._masked_preconditioner_list: tuple[Tensor, ...] = (
             self._local_preconditioner_list
         )

-        # Construct lists of dims, bytes, and numels for logging purposes.
-        self._dims_list: tuple[torch.Size, ...] = compress_list(
-            self._dims_list, distributor_selector
-        )
+        # Construct lists of numels and bytes for logging purposes.
         self._numel_list: tuple[int, ...] = tuple(
             preconditioner.numel() for preconditioner in self._local_preconditioner_list
         )
@@ -363,8 +355,6 @@ class BaseShampooPreconditionerList(
         state (Mapping[Tensor, Any]): Mapping containing optimizer state.
         block_info_list (tuple[BlockInfo, ...]): List containing corresponding BlockInfo for each block/parameter in block_list.
             Note that this should have the same length as block_list.
-        distributor_selector (tuple[bool, ...]): Distributor selector is a boolean list indicating whether a blocked parameter
-            is selected by the current Distributor.
         preconditioner_config (PreconditionerConfig): Configuration for preconditioner computation. (Default: DefaultShampooConfig)
         beta2 (float): Exponential moving average factor for Shampoo factor matrices. If beta2 = 1., will use unweighted sum. (Default: 1.0)
@@ -384,7 +374,6 @@ def __init__(  # type: ignore
         state: Mapping[Tensor, Any],
         block_info_list: tuple[BlockInfo, ...],
-        distributor_selector: tuple[bool, ...],
         preconditioner_config: PreconditionerConfig,
         beta2: float = 1.0,
         epsilon: float = 1e-12,
@@ -416,7 +405,6 @@ def __init__(
         self._initialize_state_lists(
             block_list=block_list,
             kronecker_factors_list=kronecker_factors_list,
-            distributor_selector=distributor_selector,
         )

     def _create_base_kronecker_factors(
@@ -663,16 +651,14 @@ def _initialize_state_lists(
         self,
         block_list: tuple[Tensor, ...],
         kronecker_factors_list: list[ShampooKroneckerFactorsListType],
-        distributor_selector: tuple[bool, ...],
     ) -> None:
         # Initialize local lists.
-        local_block_list = compress_list(block_list, distributor_selector)
         self._local_kronecker_factors_list: tuple[
             ShampooKroneckerFactorsListType,
             ...,
-        ] = compress_list(kronecker_factors_list, distributor_selector)
+        ] = tuple(kronecker_factors_list)
         self._local_order_list: tuple[int, ...] = tuple(
-            block.dim() for block in local_block_list
+            block.dim() for block in block_list
         )
         self._local_root_list: tuple[int, ...] = self._get_inverse_roots_from_override(
             self._inv_root_override,
@@ -689,9 +675,6 @@ def _initialize_state_lists(

         # Construct lists of bytes and numels for logging purposes.
         # NOTE: These lists are constructed across all blocked parameters.
-        self._dims_list: tuple[torch.Size, ...] = compress_list(
-            self._dims_list, distributor_selector
-        )
         self._numel_list: tuple[int, ...] = tuple(
             sum(2 * dim**2 for dim in dims) for dims in self._dims_list
         )
@@ -699,7 +682,7 @@ def _initialize_state_lists(
             numel
             * (get_dtype_size(self._factor_matrix_dtype) + get_dtype_size(block.dtype))
             // 2
-            for numel, block in zip(self._numel_list, local_block_list, strict=True)
+            for numel, block in zip(self._numel_list, block_list, strict=True)
         )

     def compress_preconditioner_list(
diff --git a/distributed_shampoo/utils/tests/shampoo_distributor_test.py b/distributed_shampoo/utils/tests/shampoo_distributor_test.py
index 5825930..ff23203 100644
--- a/distributed_shampoo/utils/tests/shampoo_distributor_test.py
+++ b/distributed_shampoo/utils/tests/shampoo_distributor_test.py
@@ -90,14 +90,6 @@ def test_update_params(self) -> None:
             actual_masked_blocked_params, expected_masked_blocked_params
         )

-    def test_distributor_selector(self) -> None:
-        # Two blocks from the linear layer, and one block from the bias layer.
-        expected_distributor_selector = (True, True, True)
-        self.assertEqual(
-            self._distributor.distributor_selector,
-            expected_distributor_selector,
-        )
-
     def test_local_grad_selector(self) -> None:
         # Explicitly disable the gradient of the bias layer and call merge_and_block_gradients()
         # to update the local gradient selector for the bias layer (i.e., 3rd block).
@@ -111,17 +103,6 @@ def test_local_grad_selector(self) -> None:
             expected_local_grad_selector,
         )

-    def test_global_blocked_params(self) -> None:
-        expected_global_params = (
-            torch.zeros(5, 5, dtype=torch.float),
-            torch.zeros(5, 5, dtype=torch.float),
-            torch.zeros(5, dtype=torch.float),
-        )
-        torch.testing.assert_close(
-            self._distributor.global_blocked_params,
-            expected_global_params,
-        )
-
     def test_local_blocked_params(self) -> None:
         # In Distributor, because there is no global vs. local boundary concept,
         # global and local blocked params are always identical.
@@ -135,8 +116,8 @@ def test_local_blocked_params(self) -> None:
             expected_local_params,
         )

-    def test_global_block_info_list(self) -> None:
-        expected_global_block_info_list = (
+    def test_local_block_info_list(self) -> None:
+        expected_local_block_info_list = (
             BlockInfo(
                 param=self._model.linear_layers[0].weight,
                 composable_block_ids=(0, "block_0"),
@@ -151,8 +132,8 @@ def test_global_block_info_list(self) -> None:
             ),
         )
         self.assertEqual(
-            self._distributor.global_block_info_list,
-            expected_global_block_info_list,
+            self._distributor.local_block_info_list,
+            expected_local_block_info_list,
         )

     def test_merge_and_block_gradients(self) -> None:
diff --git a/distributed_shampoo/utils/tests/shampoo_preconditioner_list_test.py b/distributed_shampoo/utils/tests/shampoo_preconditioner_list_test.py
index 5952fd8..2b375d5 100644
--- a/distributed_shampoo/utils/tests/shampoo_preconditioner_list_test.py
+++ b/distributed_shampoo/utils/tests/shampoo_preconditioner_list_test.py
@@ -167,11 +167,10 @@ def test_compress_preconditioner_list(self) -> None:
 class AdagradPreconditionerListTest(PreconditionerListTest):
     def _instantiate_block_list(self) -> tuple[Tensor, ...]:
         # Because maximum_preconditioner_dim = 2, self._params[0] forms a block by itself,
-        # self._params[1] are split into two blocks, and self._params[2] forms a block by itself.
+        # and self._params[1] are split into two blocks.
         return (
             self._params[0],
             *torch.split(self._params[1], 2, dim=0),
-            self._params[2],
         )

     def _instantiate_preconditioner_list(
@@ -182,7 +181,6 @@ def _instantiate_preconditioner_list(
             block_list=self._block_list,
             state=self._state,
             block_info_list=self._block_info_list,
-            distributor_selector=self._distributor_selector,
             **kwargs,
         )

@@ -190,16 +188,13 @@ def setUp(self) -> None:
         self._params = (
             torch.tensor([1.0, 2.0]),
             torch.arange(6, dtype=torch.float).reshape(3, 2),
-            # Following param will not be used due to the distributor selector below.
-            torch.tensor([torch.nan, torch.nan]),
         )
         self._state = {  # type: ignore[var-annotated]
             self._params[0]: {},
             self._params[1]: {},
-            self._params[2]: {},
         }
         # Because maximum_preconditioner_dim = 2, self._params[0] forms a block by itself,
-        # self._params[1] are split into two blocks, and self._params[2] forms a block by itself.
+        # and self._params[1] are split into two blocks.
         self._block_info_list = (
             BlockInfo(
                 param=self._params[0],
@@ -213,13 +208,7 @@ def setUp(self) -> None:
                 param=self._params[1],
                 composable_block_ids=(1, "block_1"),
             ),
-            BlockInfo(
-                param=self._params[2],
-                composable_block_ids=(2, "block_0"),
-            ),
         )
-        # Ignores the last block, which is self._params[2] itself.
-        self._distributor_selector = (True, True, True, False)
         super().setUp()

     def test_update_preconditioners_and_precondition(self) -> None:
@@ -295,7 +284,6 @@ def test_abstract_methods(self) -> None:
                     composable_block_ids=(0, "block_0"),
                 ),
             ),
-            distributor_selector=(True,),
             preconditioner_config=DefaultShampooConfig,
             beta2=1.0,
         )
@@ -478,7 +466,6 @@ def _instantiate_preconditioner_list(  # type: ignore[override]
             block_list=self._block_list,
             state=self._state,
             block_info_list=self._block_info_list,
-            distributor_selector=self._distributor_selector,
             factor_matrix_dtype=torch.float64,
             **kwargs,  # type: ignore[arg-type]
         )
@@ -722,7 +709,6 @@ def _instantiate_preconditioner_list(  # type: ignore[override]
             block_list=self._block_list,
             state=self._state,
             block_info_list=self._block_info_list,
-            distributor_selector=self._distributor_selector,
             factor_matrix_dtype=torch.float64,
             **kwargs,  # type: ignore[arg-type]
         )
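
The net effect of the patch above is that each Distributor now compresses its block info by the distributor selector once at construction time and exposes only local views (local_blocked_params, local_block_info_list), so DistributedShampoo and the preconditioner lists no longer receive global lists plus a separate distributor_selector argument. The sketch below is illustrative only, not code from the repository: ToyDistributor and its constructor arguments are hypothetical, and compress_list here is a minimal stand-in for the helper of the same name used in the diff.

from typing import TypeVar

import torch
from torch import Tensor

T = TypeVar("T")


def compress_list(complete_list: tuple[T, ...], selector: tuple[bool, ...]) -> tuple[T, ...]:
    # Minimal stand-in for the compress_list helper used in the diff:
    # keep only the entries whose selector flag is True.
    return tuple(entry for entry, keep in zip(complete_list, selector, strict=True) if keep)


class ToyDistributor:
    # Hypothetical distributor illustrating the new pattern: compress once in __init__,
    # expose only the local (per-rank) views.
    def __init__(self, blocked_params: tuple[Tensor, ...], owned_by_this_rank: tuple[bool, ...]) -> None:
        # The selector stays an implementation detail of the distributor.
        self._distributor_selector = owned_by_this_rank
        # Local views are computed exactly once at construction time.
        self._local_blocked_params = compress_list(blocked_params, self._distributor_selector)

    @property
    def local_blocked_params(self) -> tuple[Tensor, ...]:
        # Consumers (e.g., preconditioner lists) receive already-local lists and
        # therefore no longer need a distributor_selector argument of their own.
        return self._local_blocked_params


if __name__ == "__main__":
    params = (torch.zeros(5, 5), torch.zeros(5, 5), torch.zeros(5))
    distributor = ToyDistributor(params, owned_by_this_rank=(True, False, True))
    assert len(distributor.local_blocked_params) == 2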