Remove track_root_inv_residuals (#66)
Summary:

Remove the `track_root_inv_residuals` flag.

Differential Revision: D67468834
tsunghsienlee authored and facebook-github-bot committed Dec 19, 2024
1 parent ae91e28 commit a7b27f2
Showing 9 changed files with 2 additions and 234 deletions.
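
For downstream users the migration is mechanical: drop the keyword argument, since no replacement flag is introduced. A minimal sketch, assuming `DistributedShampoo` is imported from `distributed_shampoo.distributed_shampoo` as in this repository (the model and hyperparameters below are illustrative, not taken from this commit):

```python
import torch.nn as nn

from distributed_shampoo.distributed_shampoo import DistributedShampoo

model = nn.Linear(4, 2)

# Before this commit (now a TypeError, because the keyword argument no longer exists):
# optimizer = DistributedShampoo(model.parameters(), lr=0.01, track_root_inv_residuals=True)

# After this commit: simply drop the flag; every other argument is unchanged.
optimizer = DistributedShampoo(
    model.parameters(),
    lr=0.01,
    start_preconditioning_step=1,
)
```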
66 changes: 1 addition & 65 deletions distributed_shampoo/distributed_shampoo.py
@@ -91,7 +91,7 @@
)
from distributed_shampoo.utils.shampoo_utils import compress_list

from matrix_functions_types import EigenConfig, RootInvConfig
from matrix_functions_types import EigenConfig
from torch.optim.optimizer import ParamsT, StateDict

logger: logging.Logger = logging.getLogger(__name__)
@@ -285,8 +285,6 @@ class DistributedShampoo(torch.optim.Optimizer):
to different distributed training frameworks, such as distributed-data parallel (DDP) training.
Based on the configuration, determines which version of Shampoo to use. (Default: None)
preconditioner_dtype (torch.dtype): Data type for preconditioner. (Default: torch.float)
track_root_inv_residuals (bool): Track errors and residuals of root inverse. For debugging purposes.
(Default: False)
preconditioner_config (PreconditionerConfig): Configuration for preconditioner computation.
If this field is an instance of ShampooPreconditionerConfig, Shampoo uses the root inverse of the preconditioner.
If this field is an instance of EigenvalueCorrectedShampooPreconditionerConfig, Shampoo uses the corrected eigenvalues and runs Adam in the eigenbasis of the preconditioner.
@@ -317,7 +315,6 @@ def __init__(
shampoo_pt2_compile_config: ShampooPT2CompileConfig | None = None,
distributed_config: DistributedConfig | None = None,
preconditioner_dtype: torch.dtype = torch.float,
track_root_inv_residuals: bool = False,
preconditioner_config: PreconditionerConfig = DefaultShampooConfig,
) -> None:
# Hyperparameter checks.
@@ -373,8 +370,6 @@ def __init__(
raise ValueError(
f"Invalid exponent override: {inv_root_override}. Must be >= 0."
)
if track_root_inv_residuals:
logger.setLevel(logging.DEBUG)

# Provide warning/error for start_preconditioning_step.
if start_preconditioning_step == -1:
@@ -404,13 +399,6 @@
amortized_computation_config = (
preconditioner_config.amortized_computation_config
)
if (
not isinstance(amortized_computation_config, RootInvConfig)
) and track_root_inv_residuals:
raise ValueError(
f"{track_root_inv_residuals=} has to be set to False when {amortized_computation_config=} is not an instance of RootInvConfig."
)

# Set exponent multiplier if this is not provided.
if (
isinstance(amortized_computation_config, EigenConfig)
@@ -448,9 +436,6 @@ def __init__(
},
)

# Initialize non-group-related fields.
self._track_root_inv_residuals = track_root_inv_residuals

# Initialize list containing group state dictionaries.
self._per_group_state_lists: list[dict[str, Any]] = [
{} for _ in self.param_groups
@@ -779,53 +764,6 @@ def _mask_state_lists(state_lists: dict[str, Any], group: dict[str, Any]) -> None:
state_lists[DISTRIBUTOR].local_grad_selector,
)

@torch.no_grad()
@torch.compiler.disable
def _compute_and_log_root_inverse_residuals(
self,
) -> None:
"""Compute root inverse residuals over all preconditioners.
Uses infinity norm to evaluate residuals and errors.
"""

# Compute relative errors/residuals for each group.
for (group_index, group), state_lists in zip(
enumerate(self.param_groups), self._per_group_state_lists, strict=True
):
if group[PRECONDITIONER_DTYPE] == torch.float64:
expected_relative_error = 1e-7
elif group[PRECONDITIONER_DTYPE] == torch.float32:
expected_relative_error = 1e-3
else:
logger.warning(
"Expected relative error/residual not supported for precision lower than float32."
)
continue

relative_errors, relative_residuals = map(
torch.stack,
state_lists[
SHAMPOO_PRECONDITIONER_LIST
].compute_root_inverse_residuals(),
)
quantiles = torch.as_tensor(
[0, 0.25, 0.5, 0.75, 1],
device=relative_errors.device,
dtype=relative_errors.dtype,
)
logger.debug(f"Group Index: {group_index}")
logger.debug(f"Expect Relative Error <= {expected_relative_error}")
logger.debug(
f"Relative Error (||X - X_hat||_inf / ||X||_inf) Average: {torch.mean(relative_errors)}, "
f"Quantiles [0, 25, 50, 75, 100]: {torch.quantile(relative_errors, quantiles, interpolation='nearest')}"
)
logger.debug(
f"Relative Residual (||X_hat^-r - A||_inf / ||A||_inf) Average: {torch.mean(relative_residuals)}, "
"Quantiles [0, 25, 50, 75, 100]: "
f"{torch.quantile(relative_residuals, quantiles, interpolation='nearest')}"
)

@torch.no_grad()
@torch.compiler.disable
def _precondition_and_grafting(
@@ -900,8 +838,6 @@ def _update_preconditioners(
step=step,
perform_amortized_computation=perform_amortized_computation,
)
if perform_amortized_computation and self._track_root_inv_residuals:
self._compute_and_log_root_inverse_residuals()
if grafting_config_not_none:
state_lists[GRAFTING_PRECONDITIONER_LIST].update_preconditioners(
masked_grad_list=state_lists[MASKED_BLOCKED_GRADS],
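
The deleted `_compute_and_log_root_inverse_residuals` summarized two infinity-norm diagnostics per preconditioner: the relative error `||X - X_hat||_inf / ||X||_inf` against the exact inverse root, and the relative residual `||X_hat^-r - A||_inf / ||A||_inf` against the factor matrix. For anyone who still wants this check after the flag is gone, the sketch below computes both quantities for a single symmetric positive-definite matrix using only public `torch.linalg` calls and an eigendecomposition; it is an illustration, not the repository's `matrix_functions` helper:

```python
import torch


def root_inverse_diagnostics(
    A: torch.Tensor, X_hat: torch.Tensor, root: int
) -> tuple[torch.Tensor, torch.Tensor]:
    """Relative error and residual of an approximate inverse root X_hat of A.

    A must be symmetric positive definite; X_hat approximates A^(-1/root).
    """
    inf = float("inf")

    # Exact inverse root A^(-1/root) via an eigendecomposition of the symmetric A.
    eigvals, eigvecs = torch.linalg.eigh(A)
    X = eigvecs @ torch.diag(eigvals.pow(-1.0 / root)) @ eigvecs.T
    relative_error = torch.linalg.matrix_norm(X - X_hat, ord=inf) / torch.linalg.matrix_norm(X, ord=inf)

    # Reconstruct A by raising X_hat to the -root power and compare against A.
    eigvals_hat, eigvecs_hat = torch.linalg.eigh(X_hat)
    A_hat = eigvecs_hat @ torch.diag(eigvals_hat.pow(-float(root))) @ eigvecs_hat.T
    relative_residual = torch.linalg.matrix_norm(A_hat - A, ord=inf) / torch.linalg.matrix_norm(A, ord=inf)

    return relative_error, relative_residual
```

With `A = X_hat = torch.eye(2)` both quantities are exactly zero, matching the expectation encoded in the deleted `test_compute_root_inverse_residuals` further down in this diff.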
1 change: 0 additions & 1 deletion distributed_shampoo/examples/ddp_cifar10_example.py
@@ -123,7 +123,6 @@
communicate_params=args.communicate_params,
),
preconditioner_dtype=args.preconditioner_dtype,
track_root_inv_residuals=args.track_root_inv_residuals,
preconditioner_computation_type=args.preconditioner_computation_type,
)

1 change: 0 additions & 1 deletion distributed_shampoo/examples/default_cifar10_example.py
@@ -133,7 +133,6 @@ def train_default_model(
use_merge_dims=args.use_merge_dims,
distributed_config=None,
preconditioner_dtype=args.preconditioner_dtype,
track_root_inv_residuals=args.track_root_inv_residuals,
preconditioner_computation_type=args.preconditioner_computation_type,
)

1 change: 0 additions & 1 deletion distributed_shampoo/examples/fsdp_cifar10_example.py
@@ -115,7 +115,6 @@
param_to_metadata=compile_fsdp_parameter_metadata(model),
),
preconditioner_dtype=args.preconditioner_dtype,
track_root_inv_residuals=args.track_root_inv_residuals,
preconditioner_computation_type=args.preconditioner_computation_type,
)

Additional changed file (filename not shown in this view)
@@ -135,7 +135,6 @@ def create_model_and_optimizer_and_loss_fn(args, device):
use_merge_dims=args.use_merge_dims,
distributed_config=FullyShardShampooConfig(),
preconditioner_dtype=args.preconditioner_dtype,
track_root_inv_residuals=args.track_root_inv_residuals,
preconditioner_computation_type=args.preconditioner_computation_type,
)
return model, optimizer, loss_function
7 changes: 0 additions & 7 deletions distributed_shampoo/examples/trainer_utils.py
@@ -194,11 +194,6 @@ def get_args():
action="store_true",
help="Use merge dims for Shampoo.",
)
parser.add_argument(
"--track-root-inv-residuals",
action="store_true",
help="Use debug mode for examining root inverse residuals.",
)
parser.add_argument(
"--preconditioner-computation-type",
type=lambda t: enum_type_parse(t, PreconditionerComputationType),
@@ -387,7 +382,6 @@ def instantiate_optimizer(
use_merge_dims: bool,
distributed_config: DistributedConfig | None,
preconditioner_dtype: DType,
track_root_inv_residuals: bool,
preconditioner_computation_type: PreconditionerComputationType,
) -> torch.optim.Optimizer:
if optimizer_type == OptimizerType.SGD:
@@ -440,7 +434,6 @@
use_merge_dims=use_merge_dims,
distributed_config=distributed_config,
preconditioner_dtype=preconditioner_dtype.value,
track_root_inv_residuals=track_root_inv_residuals,
preconditioner_config=instantiate_preconditioner_config(
preconditioner_computation_type
),
63 changes: 0 additions & 63 deletions distributed_shampoo/tests/distributed_shampoo_test.py
@@ -22,7 +22,6 @@
from distributed_shampoo.shampoo_types import (
AdaGradGraftingConfig,
DDPShampooConfig,
DefaultEigenvalueCorrectedShampooConfig,
DefaultShampooConfig,
DistributedConfig,
GraftingConfig,
@@ -225,21 +224,6 @@ def test_setting_exponent_multiplier_with_eigen_config(self) -> None:
],
)

def test_conflict_eigenvalue_correction_and_track_root_inv_residuals(self) -> None:
with self.assertRaisesRegex(
ValueError,
re.escape(
"track_root_inv_residuals=True has to be set to False when amortized_computation_config=EighEigenvectorConfig(retry_double_precision=True) is not an instance of RootInvConfig."
),
):
DistributedShampoo(
self._model.parameters(),
lr=0.01,
start_preconditioning_step=1,
track_root_inv_residuals=True,
preconditioner_config=DefaultEigenvalueCorrectedShampooConfig,
)


class DistributedShampooTest(unittest.TestCase):
def setUp(self) -> None:
@@ -629,53 +613,6 @@ def test_load_distributed_state_dict_with_missing_param(self) -> None:
)


class DistributedShampooTrackRootInvResidualsTest(unittest.TestCase):
def _get_track_root_inverse_residuals_output(self, dtype: torch.dtype) -> list[str]:
# Create a model and a DistributedShampoo optimizer with track_root_inv_residuals enabled and the given dtype.
# The model and the optimizer use the same dtype.
model = nn.Sequential(nn.Linear(2, 1, bias=False))
model[0].weight.data = torch.tensor([1.0, 2.0], dtype=dtype)
optimizer = DistributedShampoo(
params=model.parameters(),
precondition_frequency=2,
start_preconditioning_step=2,
preconditioner_dtype=dtype,
track_root_inv_residuals=True,
)

# Run two steps of the optimizer to compute the root inverse residuals.
# Because precondition_frequency and start_preconditioning_step are both 2, there should be one call of
# _compute_and_log_root_inverse_residuals().
with self.assertLogs(level="DEBUG") as cm:
model[0].weight.grad = torch.tensor([1.0, 0.0], dtype=dtype)
optimizer.step()
model[0].weight.grad = torch.tensor([0.0, 1.0], dtype=dtype)
optimizer.step()
return [r.msg for r in cm.records]

def test_compute_and_log_root_inverse_residuals(self) -> None:
# Test that tracking root inverse residuals supports both float32 and float64.
for dtype, expected_relative_error in [
(torch.float32, 1e-3),
(torch.float64, 1e-7),
]:
with self.subTest(dtype=dtype):
msgs = self._get_track_root_inverse_residuals_output(dtype=dtype)
self.assertIn("Group Index: 0", msgs)
self.assertIn(
f"Expect Relative Error <= {expected_relative_error}", msgs
)

# Test that tracking root inverse residuals is not supported for float16.
msgs = self._get_track_root_inverse_residuals_output(dtype=torch.float16)
self.assertEqual(
msgs,
[
"Expected relative error/residual not supported for precision lower than float32."
],
)


class DistributedShampooNoneGradTest(unittest.TestCase):
def setUp(self) -> None:
self._model = nn.Sequential(
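
The deleted `DistributedShampooTrackRootInvResidualsTest` drove its assertions off unittest's `assertLogs`, which records every message emitted at or above a given level inside the `with` block. A minimal sketch of that capture pattern in isolation (the logger name and message are placeholders, not the library's):

```python
import logging
import unittest


class DebugLogCaptureTest(unittest.TestCase):
    def test_captures_debug_records(self) -> None:
        logger = logging.getLogger("example")  # placeholder logger name

        # assertLogs collects every record at or above the requested level that is
        # emitted inside the with block; the deleted test read cm.records the same way.
        with self.assertLogs(logger, level="DEBUG") as cm:
            logger.debug("Group Index: 0")

        self.assertIn("Group Index: 0", [record.msg for record in cm.records])


if __name__ == "__main__":
    unittest.main()
```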
54 changes: 1 addition & 53 deletions distributed_shampoo/utils/shampoo_preconditioner_list.py
@@ -24,12 +24,7 @@
)
from distributed_shampoo.utils.shampoo_block_info import BlockInfo
from distributed_shampoo.utils.shampoo_utils import compress_list, get_dtype_size
from matrix_functions import (
check_diagonal,
compute_matrix_root_inverse_residuals,
matrix_eigenvectors,
matrix_inverse_root,
)
from matrix_functions import check_diagonal, matrix_eigenvectors, matrix_inverse_root

from matrix_functions_types import EigenvectorConfig, RootInvConfig
from optimizer_modules import OptimizerModule
@@ -924,53 +919,6 @@ def _amortized_computation(self) -> None:
)
inv_factor_matrix.copy_(computed_inv_factor_matrix)

@torch.compiler.disable
def compute_root_inverse_residuals(
self,
) -> tuple[tuple[Tensor, ...], tuple[Tensor, ...]]:
root_inv_config = cast(
RootInvConfig,
self._preconditioner_config.amortized_computation_config,
)
relative_errors = []
relative_residuals = []

for kronecker_factors, root in zip(
self._masked_kronecker_factors_list,
self._masked_root_list,
strict=True,
):
for factor_matrix, inv_factor_matrix in zip(
kronecker_factors.factor_matrices,
kronecker_factors.inv_factor_matrices,
strict=True,
):
bias_corrected_factor_matrix = factor_matrix / self._bias_correction2
(
relative_error,
relative_residual,
) = compute_matrix_root_inverse_residuals(
A=bias_corrected_factor_matrix,
X_hat=inv_factor_matrix,
root=Fraction(
root
/ getattr(
root_inv_config,
"exponent_multiplier",
1,
)
),
epsilon=self._epsilon,
root_inv_config=root_inv_config,
)
relative_errors.append(relative_error)
relative_residuals.append(relative_residual)

return (
tuple(relative_errors),
tuple(relative_residuals),
)


class EigenvalueCorrectedShampooPreconditionerList(
BaseShampooPreconditionerList[EigenvalueCorrectedShampooKroneckerFactorsList]
Additional changed file (filename not shown in this view)
@@ -703,48 +703,6 @@ def test_inverse_roots_from_override(
test_inverse_roots_from_override(inv_root_override=2)
test_inverse_roots_from_override(inv_root_override=[2, 2, 2])

def test_compute_root_inverse_residuals(self) -> None:
"""
Create a factor matrix of size 2x2 by updating preconditioners in two steps:
Step 1. G1 = [1, 0]^T
Step 2. G2 = [0, 1]^T
L = G1 * G1^T + G2 * G2^T = [[1, 0], [0, 1]]
"""
preconditioner_list = ShampooPreconditionerList(
block_list=(self._params[0],),
state=self._state,
block_info_list=(self._block_info_list[0],),
distributor_selector=(self._distributor_selector[0],),
preconditioner_config=DefaultShampooConfig,
epsilon=0.0,
)

masked_grad_list1 = (torch.tensor([1.0, 0.0]),)
masked_grad_list2 = (torch.tensor([0.0, 1.0]),)
preconditioner_list.update_preconditioners(
masked_grad_list=masked_grad_list1,
step=torch.tensor(1),
perform_amortized_computation=False,
)
preconditioner_list.update_preconditioners(
masked_grad_list=masked_grad_list2,
step=torch.tensor(2),
perform_amortized_computation=True,
)

# Expect no relative errors and residuals because L is a diagonal matrix.
(
relative_errors,
relative_residuals,
) = preconditioner_list.compute_root_inverse_residuals()

expected_relative_errors = (torch.tensor(0.0),)
expected_relative_residuals = (torch.tensor(0.0),)

self.assertTupleEqual(relative_errors, expected_relative_errors)
self.assertTupleEqual(relative_residuals, expected_relative_residuals)


class EigenvalueCorrectedShampooPreconditionerListTest(
AbstractTest.BaseShampooPreconditionerListTest
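
The docstring of the deleted `test_compute_root_inverse_residuals` above can be verified by hand: accumulating the two gradient outer products yields the identity matrix, which is its own inverse root, so with `epsilon=0.0` the exact and computed inverse factor matrices coincide and both diagnostics are exactly zero. A small self-contained check in plain torch (it does not use the repository's `ShampooPreconditionerList`):

```python
import torch

# The two gradient steps from the deleted test's docstring.
g1 = torch.tensor([1.0, 0.0])
g2 = torch.tensor([0.0, 1.0])

# Accumulated factor matrix L = G1 * G1^T + G2 * G2^T.
L = torch.outer(g1, g1) + torch.outer(g2, g2)
assert torch.equal(L, torch.eye(2))  # L is exactly the identity

# The identity is its own inverse root, so the computed inverse factor matrix
# equals the exact one and the infinity-norm error/residual are both zero.
residual = torch.linalg.matrix_norm(torch.eye(2) - torch.linalg.inv(L), ord=float("inf"))
print(residual)  # tensor(0.)
```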