diff --git a/tests/constants.py b/tests/constants.py
index d1572c195..65c0afb33 100644
--- a/tests/constants.py
+++ b/tests/constants.py
@@ -59,6 +59,8 @@
     Ranger,
     Ranger21,
     ScalableShampoo,
+    ScheduleFreeAdamW,
+    ScheduleFreeSGD,
     Shampoo,
     SignSGD,
     SophiaH,
@@ -124,6 +126,7 @@
     'galore',
     'adalite',
     'bsam',
+    'schedulefreeadamw',
 ]
 
 VALID_LR_SCHEDULER_NAMES: List[str] = [
@@ -439,6 +442,8 @@
         5,
     ),
     (Adalite, {'lr': 1e0, 'weight_decay': 1e-3}, 5),
+    (ScheduleFreeSGD, {'lr': 1e0, 'weight_decay': 1e-3}, 5),
+    (ScheduleFreeAdamW, {'lr': 1e0, 'weight_decay': 1e-3}, 5),
 ]
 ADANORM_SUPPORTED_OPTIMIZERS: List[Tuple[Any, Dict[str, Union[float, bool, int]], int]] = [
     (AdaBelief, {'lr': 5e-1, 'weight_decay': 1e-3, 'adanorm': True}, 10),
diff --git a/tests/test_optimizers.py b/tests/test_optimizers.py
index c543c7442..4b68b6e7d 100644
--- a/tests/test_optimizers.py
+++ b/tests/test_optimizers.py
@@ -594,3 +594,17 @@ def test_dynamic_scaler():
     scaler = DynamicLossScaler(init_scale=2.0**15, scale_window=1, threshold=1e-2)
     scaler.decrease_loss_scale()
     scaler.update_scale(overflow=False)
+
+
+def test_schedule_free_train_mode():
+    param = simple_parameter(True)
+
+    opt = load_optimizer('ScheduleFreeAdamW')([param])
+    opt.reset()
+    opt.train()
+    opt.eval()
+
+    opt = load_optimizer('ScheduleFreeSGD')([param])
+    opt.reset()
+    opt.train()
+    opt.eval()