Remove track_root_inv_residuals (#66)
Summary:

Remove the `track_root_inv_residuals` flag.

Differential Revision: D67468834
tsunghsienlee authored and facebook-github-bot committed Dec 19, 2024
1 parent ae91e28 commit 9293a16
Showing 10 changed files with 9 additions and 238 deletions.
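
For callers, the practical effect is simply dropping the argument. Below is a minimal before/after sketch; the constructor arguments other than the removed flag are illustrative placeholders taken from the tests in this diff, and the import path is assumed from the repository layout shown here.

```python
import torch.nn as nn

# Import path assumed from the file layout shown in this diff.
from distributed_shampoo.distributed_shampoo import DistributedShampoo

model = nn.Sequential(nn.Linear(2, 1, bias=False))

# Before this commit, the debug flag enabled DEBUG-level logging of
# root-inverse errors/residuals:
#
#     optimizer = DistributedShampoo(
#         model.parameters(),
#         lr=0.01,
#         precondition_frequency=2,
#         start_preconditioning_step=2,
#         track_root_inv_residuals=True,  # removed by this commit
#     )
#
# After this commit, the flag is simply dropped:
optimizer = DistributedShampoo(
    model.parameters(),
    lr=0.01,
    precondition_frequency=2,
    start_preconditioning_step=2,
)
```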
66 changes: 1 addition & 65 deletions distributed_shampoo/distributed_shampoo.py
@@ -91,7 +91,7 @@
)
from distributed_shampoo.utils.shampoo_utils import compress_list

from matrix_functions_types import EigenConfig, RootInvConfig
from matrix_functions_types import EigenConfig
from torch.optim.optimizer import ParamsT, StateDict

logger: logging.Logger = logging.getLogger(__name__)
@@ -285,8 +285,6 @@ class DistributedShampoo(torch.optim.Optimizer):
to different distributed training frameworks, such as distributed-data parallel (DDP) training.
Based on the configuration, determines which version of Shampoo to use. (Default: None)
preconditioner_dtype (torch.dtype): Data type for preconditioner. (Default: None)
track_root_inv_residuals (bool): Track errors and residuals of root inverse. For debugging purposes.
(Default: False)
preconditioner_config (PreconditionerConfig): Configuration for preconditioner computation.
If this field is an instance of ShampooPreconditionerConfig, Shampoo uses the root inverse of the preconditioner.
If this field is an instance of EigenvalueCorrectedShampooPreconditionerConfig, Shampoo corrects the eigenvalues and runs Adam in the eigenbasis of the preconditioner.
@@ -317,7 +315,6 @@ def __init__(
shampoo_pt2_compile_config: ShampooPT2CompileConfig | None = None,
distributed_config: DistributedConfig | None = None,
preconditioner_dtype: torch.dtype = torch.float,
track_root_inv_residuals: bool = False,
preconditioner_config: PreconditionerConfig = DefaultShampooConfig,
) -> None:
# Hyperparameter checks.
@@ -373,8 +370,6 @@ def __init__(
raise ValueError(
f"Invalid exponent override: {inv_root_override}. Must be >= 0."
)
if track_root_inv_residuals:
logger.setLevel(logging.DEBUG)

# Provide warning/error for start_preconditioning_step.
if start_preconditioning_step == -1:
@@ -404,13 +399,6 @@
amortized_computation_config = (
preconditioner_config.amortized_computation_config
)
if (
not isinstance(amortized_computation_config, RootInvConfig)
) and track_root_inv_residuals:
raise ValueError(
f"{track_root_inv_residuals=} has to be set to False when {amortized_computation_config=} is not an instance of RootInvConfig."
)

# Set exponent multiplier if this is not provided.
if (
isinstance(amortized_computation_config, EigenConfig)
@@ -448,9 +436,6 @@ def __init__(
},
)

# Initialize non-group-related fields.
self._track_root_inv_residuals = track_root_inv_residuals

# Initialize list containing group state dictionaries.
self._per_group_state_lists: list[dict[str, Any]] = [
{} for _ in self.param_groups
@@ -779,53 +764,6 @@ def _mask_state_lists(state_lists: dict[str, Any], group: dict[str, Any]) -> Non
state_lists[DISTRIBUTOR].local_grad_selector,
)

@torch.no_grad()
@torch.compiler.disable
def _compute_and_log_root_inverse_residuals(
self,
) -> None:
"""Compute root inverse residuals over all preconditioners.
Uses infinity norm to evaluate residuals and errors.
"""

# Compute relative errors/residuals for each group.
for (group_index, group), state_lists in zip(
enumerate(self.param_groups), self._per_group_state_lists, strict=True
):
if group[PRECONDITIONER_DTYPE] == torch.float64:
expected_relative_error = 1e-7
elif group[PRECONDITIONER_DTYPE] == torch.float32:
expected_relative_error = 1e-3
else:
logger.warning(
"Expected relative error/residual not supported for precision lower than float32."
)
continue

relative_errors, relative_residuals = map(
torch.stack,
state_lists[
SHAMPOO_PRECONDITIONER_LIST
].compute_root_inverse_residuals(),
)
quantiles = torch.as_tensor(
[0, 0.25, 0.5, 0.75, 1],
device=relative_errors.device,
dtype=relative_errors.dtype,
)
logger.debug(f"Group Index: {group_index}")
logger.debug(f"Expect Relative Error <= {expected_relative_error}")
logger.debug(
f"Relative Error (||X - X_hat||_inf / ||X||_inf) Average: {torch.mean(relative_errors)}, "
f"Quantiles [0, 25, 50, 75, 100]: {torch.quantile(relative_errors, quantiles, interpolation='nearest')}"
)
logger.debug(
f"Relative Residual (||X_hat^-r - A||_inf / ||A||_inf) Average: {torch.mean(relative_residuals)}, "
"Quantiles [0, 25, 50, 75, 100]: "
f"{torch.quantile(relative_residuals, quantiles, interpolation='nearest')}"
)

@torch.no_grad()
@torch.compiler.disable
def _precondition_and_grafting(
@@ -900,8 +838,6 @@ def _update_preconditioners(
step=step,
perform_amortized_computation=perform_amortized_computation,
)
if perform_amortized_computation and self._track_root_inv_residuals:
self._compute_and_log_root_inverse_residuals()
if grafting_config_not_none:
state_lists[GRAFTING_PRECONDITIONER_LIST].update_preconditioners(
masked_grad_list=state_lists[MASKED_BLOCKED_GRADS],
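For anyone who relied on this debug output: per Kronecker factor, the removed `_compute_and_log_root_inverse_residuals` logged a relative error ||X - X_hat||_inf / ||X||_inf and a relative residual ||X_hat^-r - A||_inf / ||A||_inf, where A is the bias-corrected factor matrix, X its exact inverse root, and X_hat the computed approximation. The sketch below reproduces those two quantities in isolation; it is not the library's `compute_matrix_root_inverse_residuals` (which also handles epsilon regularization and the configured exponent multiplier), just an illustration under the assumption that A is symmetric positive definite.

```python
import torch


def root_inverse_residuals(
    A: torch.Tensor, X_hat: torch.Tensor, root: int
) -> tuple[torch.Tensor, torch.Tensor]:
    """Relative error and residual of an approximate inverse root X_hat ~ A^(-1/root)."""
    inf = float("inf")

    # Reference inverse root via an eigendecomposition of the symmetric matrix A.
    eigenvalues, eigenvectors = torch.linalg.eigh(A)
    X = eigenvectors @ torch.diag(eigenvalues.pow(-1.0 / root)) @ eigenvectors.T

    # Relative error: distance between the computed X_hat and the reference X.
    relative_error = torch.linalg.matrix_norm(
        X - X_hat, ord=inf
    ) / torch.linalg.matrix_norm(X, ord=inf)

    # Relative residual: raise X_hat back to the -root power and compare with A.
    A_reconstructed = torch.linalg.matrix_power(torch.linalg.inv(X_hat), root)
    relative_residual = torch.linalg.matrix_norm(
        A_reconstructed - A, ord=inf
    ) / torch.linalg.matrix_norm(A, ord=inf)

    return relative_error, relative_residual
```

As in the thresholds logged by the removed code, a well-conditioned float64 factor matrix is expected to give a relative error around 1e-7, and float32 around 1e-3.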
1 change: 0 additions & 1 deletion distributed_shampoo/examples/ddp_cifar10_example.py
@@ -123,7 +123,6 @@
communicate_params=args.communicate_params,
),
preconditioner_dtype=args.preconditioner_dtype,
track_root_inv_residuals=args.track_root_inv_residuals,
preconditioner_computation_type=args.preconditioner_computation_type,
)

1 change: 0 additions & 1 deletion distributed_shampoo/examples/default_cifar10_example.py
@@ -133,7 +133,6 @@ def train_default_model(
use_merge_dims=args.use_merge_dims,
distributed_config=None,
preconditioner_dtype=args.preconditioner_dtype,
track_root_inv_residuals=args.track_root_inv_residuals,
preconditioner_computation_type=args.preconditioner_computation_type,
)

1 change: 0 additions & 1 deletion distributed_shampoo/examples/fsdp_cifar10_example.py
@@ -115,7 +115,6 @@
param_to_metadata=compile_fsdp_parameter_metadata(model),
),
preconditioner_dtype=args.preconditioner_dtype,
track_root_inv_residuals=args.track_root_inv_residuals,
preconditioner_computation_type=args.preconditioner_computation_type,
)

@@ -135,7 +135,6 @@ def create_model_and_optimizer_and_loss_fn(args, device):
use_merge_dims=args.use_merge_dims,
distributed_config=FullyShardShampooConfig(),
preconditioner_dtype=args.preconditioner_dtype,
track_root_inv_residuals=args.track_root_inv_residuals,
preconditioner_computation_type=args.preconditioner_computation_type,
)
return model, optimizer, loss_function
7 changes: 0 additions & 7 deletions distributed_shampoo/examples/trainer_utils.py
@@ -194,11 +194,6 @@ def get_args():
action="store_true",
help="Use merge dims for Shampoo.",
)
parser.add_argument(
"--track-root-inv-residuals",
action="store_true",
help="Use debug mode for examining root inverse residuals.",
)
parser.add_argument(
"--preconditioner-computation-type",
type=lambda t: enum_type_parse(t, PreconditionerComputationType),
@@ -387,7 +382,6 @@ def instantiate_optimizer(
use_merge_dims: bool,
distributed_config: DistributedConfig | None,
preconditioner_dtype: DType,
track_root_inv_residuals: bool,
preconditioner_computation_type: PreconditionerComputationType,
) -> torch.optim.Optimizer:
if optimizer_type == OptimizerType.SGD:
@@ -440,7 +434,6 @@ def instantiate_optimizer(
use_merge_dims=use_merge_dims,
distributed_config=distributed_config,
preconditioner_dtype=preconditioner_dtype.value,
track_root_inv_residuals=track_root_inv_residuals,
preconditioner_config=instantiate_preconditioner_config(
preconditioner_computation_type
),
63 changes: 0 additions & 63 deletions distributed_shampoo/tests/distributed_shampoo_test.py
@@ -22,7 +22,6 @@
from distributed_shampoo.shampoo_types import (
AdaGradGraftingConfig,
DDPShampooConfig,
DefaultEigenvalueCorrectedShampooConfig,
DefaultShampooConfig,
DistributedConfig,
GraftingConfig,
@@ -225,21 +224,6 @@ def test_setting_exponent_multiplier_with_eigen_config(self) -> None:
],
)

def test_conflict_eigenvalue_correction_and_track_root_inv_residuals(self) -> None:
with self.assertRaisesRegex(
ValueError,
re.escape(
"track_root_inv_residuals=True has to be set to False when amortized_computation_config=EighEigenvectorConfig(retry_double_precision=True) is not an instance of RootInvConfig."
),
):
DistributedShampoo(
self._model.parameters(),
lr=0.01,
start_preconditioning_step=1,
track_root_inv_residuals=True,
preconditioner_config=DefaultEigenvalueCorrectedShampooConfig,
)


class DistributedShampooTest(unittest.TestCase):
def setUp(self) -> None:
@@ -629,53 +613,6 @@ def test_load_distributed_state_dict_with_missing_param(self) -> None:
)


class DistributedShampooTrackRootInvResidualsTest(unittest.TestCase):
def _get_track_root_inverse_residuals_output(self, dtype: torch.dtype) -> list[str]:
# Create a model and a DistributedShampoo optimizer with enabled track_root_inv_residuals and corresponding dtype.
# The dtype of the model and the optimizer are the same.
model = nn.Sequential(nn.Linear(2, 1, bias=False))
model[0].weight.data = torch.tensor([1.0, 2.0], dtype=dtype)
optimizer = DistributedShampoo(
params=model.parameters(),
precondition_frequency=2,
start_preconditioning_step=2,
preconditioner_dtype=dtype,
track_root_inv_residuals=True,
)

# Run two steps of the optimizer to compute the root inverse residuals.
# Because precondition_frequency and start_preconditioning_step are both 2, there should be one call of
# _compute_and_log_root_inverse_residuals().
with self.assertLogs(level="DEBUG") as cm:
model[0].weight.grad = torch.tensor([1.0, 0.0], dtype=dtype)
optimizer.step()
model[0].weight.grad = torch.tensor([0.0, 1.0], dtype=dtype)
optimizer.step()
return [r.msg for r in cm.records]

def test_compute_and_log_root_inverse_residuals(self) -> None:
# Test the cases that tracking root inverse residuals support both float32 and float64.
for dtype, expected_relative_error in [
(torch.float32, 1e-3),
(torch.float64, 1e-7),
]:
with self.subTest(dtype=dtype):
msgs = self._get_track_root_inverse_residuals_output(dtype=dtype)
self.assertIn("Group Index: 0", msgs)
self.assertIn(
f"Expect Relative Error <= {expected_relative_error}", msgs
)

# Test the case that tracking root inverse residuals does not support float16.
msgs = self._get_track_root_inverse_residuals_output(dtype=torch.float16)
self.assertEqual(
msgs,
[
"Expected relative error/residual not supported for precision lower than float32."
],
)


class DistributedShampooNoneGradTest(unittest.TestCase):
def setUp(self) -> None:
self._model = nn.Sequential(
11 changes: 7 additions & 4 deletions distributed_shampoo/utils/gpu_tests/shampoo_dist_utils_test.py
@@ -9,6 +9,7 @@

#!/usr/bin/env python3

from functools import partial
from unittest import mock

import torch
@@ -42,20 +43,22 @@ def _verify_deivce_mesh(self, device_mesh: DeviceMesh) -> None:
(shard_mesh.get_group(), replicate_mesh.get_group()),
)

@with_comms # type: ignore
@with_comms
def test_get_device_mesh(self) -> None:
mesh = tuple(
map(
tuple, # type: ignore
# Some type-checkers are not able to recognize the `tuple` below as a function. Use `partial` here to explicitly make a Callable for those type-checkers.
partial(tuple),
torch.tensor(range(self.world_size))
.view(-1, self.world_size // 2)
.tolist(),
)
)

device_type = getattr(self, "device_type", "cpu")
self._verify_deivce_mesh(
device_mesh=get_device_mesh(
device_type=self.device_type, # type: ignore
device_type=device_type,
mesh=mesh,
mesh_dim_names=("replicate", "shard"),
)
@@ -69,7 +72,7 @@ def test_get_device_mesh(self) -> None:
"__init__",
) as mock_device_mesh_init:
device_mesh = get_device_mesh(
device_type=self.device_type, # type: ignore
device_type=device_type,
mesh=mesh,
mesh_dim_names=("replicate", "shard"),
)
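One note on the `partial(tuple)` change above: `functools.partial(tuple)` with no bound arguments behaves exactly like `tuple`, so the constructed mesh is unchanged; the wrapper exists only to satisfy type-checkers about the callable passed to `map`. A small illustrative check follows (the world size of 4 is a placeholder; in the test it comes from the distributed test fixture):

```python
from functools import partial

import torch

world_size = 4  # placeholder value for illustration

rows = torch.tensor(range(world_size)).view(-1, world_size // 2).tolist()

# Both forms build the same ((0, 1), (2, 3)) mesh for a world size of 4.
mesh_plain = tuple(map(tuple, rows))
mesh_wrapped = tuple(map(partial(tuple), rows))
assert mesh_plain == mesh_wrapped == ((0, 1), (2, 3))
```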
54 changes: 1 addition & 53 deletions distributed_shampoo/utils/shampoo_preconditioner_list.py
@@ -24,12 +24,7 @@
)
from distributed_shampoo.utils.shampoo_block_info import BlockInfo
from distributed_shampoo.utils.shampoo_utils import compress_list, get_dtype_size
from matrix_functions import (
check_diagonal,
compute_matrix_root_inverse_residuals,
matrix_eigenvectors,
matrix_inverse_root,
)
from matrix_functions import check_diagonal, matrix_eigenvectors, matrix_inverse_root

from matrix_functions_types import EigenvectorConfig, RootInvConfig
from optimizer_modules import OptimizerModule
@@ -924,53 +919,6 @@ def _amortized_computation(self) -> None:
)
inv_factor_matrix.copy_(computed_inv_factor_matrix)

@torch.compiler.disable
def compute_root_inverse_residuals(
self,
) -> tuple[tuple[Tensor, ...], tuple[Tensor, ...]]:
root_inv_config = cast(
RootInvConfig,
self._preconditioner_config.amortized_computation_config,
)
relative_errors = []
relative_residuals = []

for kronecker_factors, root in zip(
self._masked_kronecker_factors_list,
self._masked_root_list,
strict=True,
):
for factor_matrix, inv_factor_matrix in zip(
kronecker_factors.factor_matrices,
kronecker_factors.inv_factor_matrices,
strict=True,
):
bias_corrected_factor_matrix = factor_matrix / self._bias_correction2
(
relative_error,
relative_residual,
) = compute_matrix_root_inverse_residuals(
A=bias_corrected_factor_matrix,
X_hat=inv_factor_matrix,
root=Fraction(
root
/ getattr(
root_inv_config,
"exponent_multiplier",
1,
)
),
epsilon=self._epsilon,
root_inv_config=root_inv_config,
)
relative_errors.append(relative_error)
relative_residuals.append(relative_residual)

return (
tuple(relative_errors),
tuple(relative_residuals),
)


class EigenvalueCorrectedShampooPreconditionerList(
BaseShampooPreconditionerList[EigenvalueCorrectedShampooKroneckerFactorsList]