From a7b27f235ee6b5ebcf4cce558b2cdf7757b20585 Mon Sep 17 00:00:00 2001 From: Tsung-Hsien Lee Date: Thu, 19 Dec 2024 12:00:52 -0800 Subject: [PATCH] Remove track_root_inv_residuals (#66) Summary: Deprecate `track_root_inv_residual` flag. Differential Revision: D67468834 --- distributed_shampoo/distributed_shampoo.py | 66 +------------------ .../examples/ddp_cifar10_example.py | 1 - .../examples/default_cifar10_example.py | 1 - .../examples/fsdp_cifar10_example.py | 1 - .../examples/fully_shard_cifar10_example.py | 1 - distributed_shampoo/examples/trainer_utils.py | 7 -- .../tests/distributed_shampoo_test.py | 63 ------------------ .../utils/shampoo_preconditioner_list.py | 54 +-------------- .../tests/shampoo_preconditioner_list_test.py | 42 ------------ 9 files changed, 2 insertions(+), 234 deletions(-) diff --git a/distributed_shampoo/distributed_shampoo.py b/distributed_shampoo/distributed_shampoo.py index 7f7c9bc..a922d61 100644 --- a/distributed_shampoo/distributed_shampoo.py +++ b/distributed_shampoo/distributed_shampoo.py @@ -91,7 +91,7 @@ ) from distributed_shampoo.utils.shampoo_utils import compress_list -from matrix_functions_types import EigenConfig, RootInvConfig +from matrix_functions_types import EigenConfig from torch.optim.optimizer import ParamsT, StateDict logger: logging.Logger = logging.getLogger(__name__) @@ -285,8 +285,6 @@ class DistributedShampoo(torch.optim.Optimizer): to different distributed training frameworks, such as distributed-data parallel (DDP) training. Based on the configuration, determines which version of Shampoo to use. (Default: None) preconditioner_dtype (torch.dtype): Data type for preconditioner. (Default: None) - track_root_inv_residuals (bool): Track errors and residuals of root inverse. For debugging purposes. - (Default: False) preconditioner_config (PreconditionerConfig): Configuration for preconditioner computation. If this field is an instance ShampooPreconditionerConfig, Shampoo uses the root inverse of the preconditioner. If this field is an instance EigenvalueCorrectedShampooPreconditionerConfig Shampoo uses corrected the eigenvalues/running Adam in the eigenbasis of preconditioner. @@ -317,7 +315,6 @@ def __init__( shampoo_pt2_compile_config: ShampooPT2CompileConfig | None = None, distributed_config: DistributedConfig | None = None, preconditioner_dtype: torch.dtype = torch.float, - track_root_inv_residuals: bool = False, preconditioner_config: PreconditionerConfig = DefaultShampooConfig, ) -> None: # Hyperparameter checks. @@ -373,8 +370,6 @@ def __init__( raise ValueError( f"Invalid exponent override: {inv_root_override}. Must be >= 0." ) - if track_root_inv_residuals: - logger.setLevel(logging.DEBUG) # Provide warning/error for start_preconditioning_step. if start_preconditioning_step == -1: @@ -404,13 +399,6 @@ def __init__( amortized_computation_config = ( preconditioner_config.amortized_computation_config ) - if ( - not isinstance(amortized_computation_config, RootInvConfig) - ) and track_root_inv_residuals: - raise ValueError( - f"{track_root_inv_residuals=} has to be set to False when {amortized_computation_config=} is not an instance of RootInvConfig." - ) - # Set exponent multiplier if this is not provided. if ( isinstance(amortized_computation_config, EigenConfig) @@ -448,9 +436,6 @@ def __init__( }, ) - # Initialize non-group-related fields. - self._track_root_inv_residuals = track_root_inv_residuals - # Initialize list containing group state dictionaries. self._per_group_state_lists: list[dict[str, Any]] = [ {} for _ in self.param_groups @@ -779,53 +764,6 @@ def _mask_state_lists(state_lists: dict[str, Any], group: dict[str, Any]) -> Non state_lists[DISTRIBUTOR].local_grad_selector, ) - @torch.no_grad() - @torch.compiler.disable - def _compute_and_log_root_inverse_residuals( - self, - ) -> None: - """Compute root inverse residuals over all preconditioners. - - Uses infinity norm to evaluate residuals and errors. - """ - - # Compute relative errors/residuals for each group. - for (group_index, group), state_lists in zip( - enumerate(self.param_groups), self._per_group_state_lists, strict=True - ): - if group[PRECONDITIONER_DTYPE] == torch.float64: - expected_relative_error = 1e-7 - elif group[PRECONDITIONER_DTYPE] == torch.float32: - expected_relative_error = 1e-3 - else: - logger.warning( - "Expected relative error/residual not supported for precision lower than float32." - ) - continue - - relative_errors, relative_residuals = map( - torch.stack, - state_lists[ - SHAMPOO_PRECONDITIONER_LIST - ].compute_root_inverse_residuals(), - ) - quantiles = torch.as_tensor( - [0, 0.25, 0.5, 0.75, 1], - device=relative_errors.device, - dtype=relative_errors.dtype, - ) - logger.debug(f"Group Index: {group_index}") - logger.debug(f"Expect Relative Error <= {expected_relative_error}") - logger.debug( - f"Relative Error (||X - X_hat||_inf / ||X||_inf) Average: {torch.mean(relative_errors)}, " - f"Quantiles [0, 25, 50, 75, 100]: {torch.quantile(relative_errors, quantiles, interpolation='nearest')}" - ) - logger.debug( - f"Relative Residual (||X_hat^-r - A||_inf / ||A||_inf) Average: {torch.mean(relative_residuals)}, " - "Quantiles [0, 25, 50, 75, 100]: " - f"{torch.quantile(relative_residuals, quantiles, interpolation='nearest')}" - ) - @torch.no_grad() @torch.compiler.disable def _precondition_and_grafting( @@ -900,8 +838,6 @@ def _update_preconditioners( step=step, perform_amortized_computation=perform_amortized_computation, ) - if perform_amortized_computation and self._track_root_inv_residuals: - self._compute_and_log_root_inverse_residuals() if grafting_config_not_none: state_lists[GRAFTING_PRECONDITIONER_LIST].update_preconditioners( masked_grad_list=state_lists[MASKED_BLOCKED_GRADS], diff --git a/distributed_shampoo/examples/ddp_cifar10_example.py b/distributed_shampoo/examples/ddp_cifar10_example.py index 65360aa..2a5a32d 100644 --- a/distributed_shampoo/examples/ddp_cifar10_example.py +++ b/distributed_shampoo/examples/ddp_cifar10_example.py @@ -123,7 +123,6 @@ communicate_params=args.communicate_params, ), preconditioner_dtype=args.preconditioner_dtype, - track_root_inv_residuals=args.track_root_inv_residuals, preconditioner_computation_type=args.preconditioner_computation_type, ) diff --git a/distributed_shampoo/examples/default_cifar10_example.py b/distributed_shampoo/examples/default_cifar10_example.py index 8fcfbc6..add8575 100644 --- a/distributed_shampoo/examples/default_cifar10_example.py +++ b/distributed_shampoo/examples/default_cifar10_example.py @@ -133,7 +133,6 @@ def train_default_model( use_merge_dims=args.use_merge_dims, distributed_config=None, preconditioner_dtype=args.preconditioner_dtype, - track_root_inv_residuals=args.track_root_inv_residuals, preconditioner_computation_type=args.preconditioner_computation_type, ) diff --git a/distributed_shampoo/examples/fsdp_cifar10_example.py b/distributed_shampoo/examples/fsdp_cifar10_example.py index cfedf46..34a73d4 100644 --- a/distributed_shampoo/examples/fsdp_cifar10_example.py +++ b/distributed_shampoo/examples/fsdp_cifar10_example.py @@ -115,7 +115,6 @@ param_to_metadata=compile_fsdp_parameter_metadata(model), ), preconditioner_dtype=args.preconditioner_dtype, - track_root_inv_residuals=args.track_root_inv_residuals, preconditioner_computation_type=args.preconditioner_computation_type, ) diff --git a/distributed_shampoo/examples/fully_shard_cifar10_example.py b/distributed_shampoo/examples/fully_shard_cifar10_example.py index 7cae037..9d53784 100644 --- a/distributed_shampoo/examples/fully_shard_cifar10_example.py +++ b/distributed_shampoo/examples/fully_shard_cifar10_example.py @@ -135,7 +135,6 @@ def create_model_and_optimizer_and_loss_fn(args, device): use_merge_dims=args.use_merge_dims, distributed_config=FullyShardShampooConfig(), preconditioner_dtype=args.preconditioner_dtype, - track_root_inv_residuals=args.track_root_inv_residuals, preconditioner_computation_type=args.preconditioner_computation_type, ) return model, optimizer, loss_function diff --git a/distributed_shampoo/examples/trainer_utils.py b/distributed_shampoo/examples/trainer_utils.py index eb84fd6..8b00f70 100644 --- a/distributed_shampoo/examples/trainer_utils.py +++ b/distributed_shampoo/examples/trainer_utils.py @@ -194,11 +194,6 @@ def get_args(): action="store_true", help="Use merge dims for Shampoo.", ) - parser.add_argument( - "--track-root-inv-residuals", - action="store_true", - help="Use debug mode for examining root inverse residuals.", - ) parser.add_argument( "--preconditioner-computation-type", type=lambda t: enum_type_parse(t, PreconditionerComputationType), @@ -387,7 +382,6 @@ def instantiate_optimizer( use_merge_dims: bool, distributed_config: DistributedConfig | None, preconditioner_dtype: DType, - track_root_inv_residuals: bool, preconditioner_computation_type: PreconditionerComputationType, ) -> torch.optim.Optimizer: if optimizer_type == OptimizerType.SGD: @@ -440,7 +434,6 @@ def instantiate_optimizer( use_merge_dims=use_merge_dims, distributed_config=distributed_config, preconditioner_dtype=preconditioner_dtype.value, - track_root_inv_residuals=track_root_inv_residuals, preconditioner_config=instantiate_preconditioner_config( preconditioner_computation_type ), diff --git a/distributed_shampoo/tests/distributed_shampoo_test.py b/distributed_shampoo/tests/distributed_shampoo_test.py index b861a61..455743d 100644 --- a/distributed_shampoo/tests/distributed_shampoo_test.py +++ b/distributed_shampoo/tests/distributed_shampoo_test.py @@ -22,7 +22,6 @@ from distributed_shampoo.shampoo_types import ( AdaGradGraftingConfig, DDPShampooConfig, - DefaultEigenvalueCorrectedShampooConfig, DefaultShampooConfig, DistributedConfig, GraftingConfig, @@ -225,21 +224,6 @@ def test_setting_exponent_multiplier_with_eigen_config(self) -> None: ], ) - def test_conflict_eigenvalue_correction_and_track_root_inv_residuals(self) -> None: - with self.assertRaisesRegex( - ValueError, - re.escape( - "track_root_inv_residuals=True has to be set to False when amortized_computation_config=EighEigenvectorConfig(retry_double_precision=True) is not an instance of RootInvConfig." - ), - ): - DistributedShampoo( - self._model.parameters(), - lr=0.01, - start_preconditioning_step=1, - track_root_inv_residuals=True, - preconditioner_config=DefaultEigenvalueCorrectedShampooConfig, - ) - class DistributedShampooTest(unittest.TestCase): def setUp(self) -> None: @@ -629,53 +613,6 @@ def test_load_distributed_state_dict_with_missing_param(self) -> None: ) -class DistributedShampooTrackRootInvResidualsTest(unittest.TestCase): - def _get_track_root_inverse_residuals_output(self, dtype: torch.dtype) -> list[str]: - # Create a model and a DistributedShampoo optimizer with enabled track_root_inv_residuals and corresponding dtype. - # The dtype of the model and the optimizer are the same. - model = nn.Sequential(nn.Linear(2, 1, bias=False)) - model[0].weight.data = torch.tensor([1.0, 2.0], dtype=dtype) - optimizer = DistributedShampoo( - params=model.parameters(), - precondition_frequency=2, - start_preconditioning_step=2, - preconditioner_dtype=dtype, - track_root_inv_residuals=True, - ) - - # Run two steps of the optimizer to compute the root inverse residuals. - # Because precondition_frequency and start_preconditioning_step are both 2, there should be one call of - # _compute_and_log_root_inverse_residuals(). - with self.assertLogs(level="DEBUG") as cm: - model[0].weight.grad = torch.tensor([1.0, 0.0], dtype=dtype) - optimizer.step() - model[0].weight.grad = torch.tensor([0.0, 1.0], dtype=dtype) - optimizer.step() - return [r.msg for r in cm.records] - - def test_compute_and_log_root_inverse_residuals(self) -> None: - # Test the cases that tracking root inverse residuals support both float32 and float64. - for dtype, expected_relative_error in [ - (torch.float32, 1e-3), - (torch.float64, 1e-7), - ]: - with self.subTest(dtype=dtype): - msgs = self._get_track_root_inverse_residuals_output(dtype=dtype) - self.assertIn("Group Index: 0", msgs) - self.assertIn( - f"Expect Relative Error <= {expected_relative_error}", msgs - ) - - # Test the case that tracking root inverse residuals does not support float16. - msgs = self._get_track_root_inverse_residuals_output(dtype=torch.float16) - self.assertEqual( - msgs, - [ - "Expected relative error/residual not supported for precision lower than float32." - ], - ) - - class DistributedShampooNoneGradTest(unittest.TestCase): def setUp(self) -> None: self._model = nn.Sequential( diff --git a/distributed_shampoo/utils/shampoo_preconditioner_list.py b/distributed_shampoo/utils/shampoo_preconditioner_list.py index d0a0597..8f0fd99 100644 --- a/distributed_shampoo/utils/shampoo_preconditioner_list.py +++ b/distributed_shampoo/utils/shampoo_preconditioner_list.py @@ -24,12 +24,7 @@ ) from distributed_shampoo.utils.shampoo_block_info import BlockInfo from distributed_shampoo.utils.shampoo_utils import compress_list, get_dtype_size -from matrix_functions import ( - check_diagonal, - compute_matrix_root_inverse_residuals, - matrix_eigenvectors, - matrix_inverse_root, -) +from matrix_functions import check_diagonal, matrix_eigenvectors, matrix_inverse_root from matrix_functions_types import EigenvectorConfig, RootInvConfig from optimizer_modules import OptimizerModule @@ -924,53 +919,6 @@ def _amortized_computation(self) -> None: ) inv_factor_matrix.copy_(computed_inv_factor_matrix) - @torch.compiler.disable - def compute_root_inverse_residuals( - self, - ) -> tuple[tuple[Tensor, ...], tuple[Tensor, ...]]: - root_inv_config = cast( - RootInvConfig, - self._preconditioner_config.amortized_computation_config, - ) - relative_errors = [] - relative_residuals = [] - - for kronecker_factors, root in zip( - self._masked_kronecker_factors_list, - self._masked_root_list, - strict=True, - ): - for factor_matrix, inv_factor_matrix in zip( - kronecker_factors.factor_matrices, - kronecker_factors.inv_factor_matrices, - strict=True, - ): - bias_corrected_factor_matrix = factor_matrix / self._bias_correction2 - ( - relative_error, - relative_residual, - ) = compute_matrix_root_inverse_residuals( - A=bias_corrected_factor_matrix, - X_hat=inv_factor_matrix, - root=Fraction( - root - / getattr( - root_inv_config, - "exponent_multiplier", - 1, - ) - ), - epsilon=self._epsilon, - root_inv_config=root_inv_config, - ) - relative_errors.append(relative_error) - relative_residuals.append(relative_residual) - - return ( - tuple(relative_errors), - tuple(relative_residuals), - ) - class EigenvalueCorrectedShampooPreconditionerList( BaseShampooPreconditionerList[EigenvalueCorrectedShampooKroneckerFactorsList] diff --git a/distributed_shampoo/utils/tests/shampoo_preconditioner_list_test.py b/distributed_shampoo/utils/tests/shampoo_preconditioner_list_test.py index dd41ae5..15d11ed 100644 --- a/distributed_shampoo/utils/tests/shampoo_preconditioner_list_test.py +++ b/distributed_shampoo/utils/tests/shampoo_preconditioner_list_test.py @@ -703,48 +703,6 @@ def test_inverse_roots_from_override( test_inverse_roots_from_override(inv_root_override=2) test_inverse_roots_from_override(inv_root_override=[2, 2, 2]) - def test_compute_root_inverse_residuals(self) -> None: - """ - Create a factor matrix of size 2x2 by updating preconditioners in two steps: - Step 1. G1 = [1, 0]^T - Step 2. G2 = [0, 1]^T - - L = G1 * G1^T + G2 * G2^T = [[1, 0], [0, 1]] - """ - preconditioner_list = ShampooPreconditionerList( - block_list=(self._params[0],), - state=self._state, - block_info_list=(self._block_info_list[0],), - distributor_selector=(self._distributor_selector[0],), - preconditioner_config=DefaultShampooConfig, - epsilon=0.0, - ) - - masked_grad_list1 = (torch.tensor([1.0, 0.0]),) - masked_grad_list2 = (torch.tensor([0.0, 1.0]),) - preconditioner_list.update_preconditioners( - masked_grad_list=masked_grad_list1, - step=torch.tensor(1), - perform_amortized_computation=False, - ) - preconditioner_list.update_preconditioners( - masked_grad_list=masked_grad_list2, - step=torch.tensor(2), - perform_amortized_computation=True, - ) - - # Expect no relative errors and residuals because L is a diagonal matrix. - ( - relative_errors, - relative_residuals, - ) = preconditioner_list.compute_root_inverse_residuals() - - expected_relative_errors = (torch.tensor(0.0),) - expected_relative_residuals = (torch.tensor(0.0),) - - self.assertTupleEqual(relative_errors, expected_relative_errors) - self.assertTupleEqual(relative_residuals, expected_relative_residuals) - class EigenvalueCorrectedShampooPreconditionerListTest( AbstractTest.BaseShampooPreconditionerListTest