Skip to content

Commit

Permalink
update: test_schedule_free_train_mode
Browse files Browse the repository at this point in the history
  • Loading branch information
kozistr committed May 5, 2024
1 parent 0f2d6b7 commit 723f9a0
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 0 deletions.
5 changes: 5 additions & 0 deletions tests/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@
Ranger,
Ranger21,
ScalableShampoo,
ScheduleFreeAdamW,
ScheduleFreeSGD,
Shampoo,
SignSGD,
SophiaH,
Expand Down Expand Up @@ -124,6 +126,7 @@
'galore',
'adalite',
'bsam',
'schedulefreeadamw',
]

VALID_LR_SCHEDULER_NAMES: List[str] = [
Expand Down Expand Up @@ -439,6 +442,8 @@
5,
),
(Adalite, {'lr': 1e0, 'weight_decay': 1e-3}, 5),
(ScheduleFreeSGD, {'lr': 1e0, 'weight_decay': 1e-3}, 5),
(ScheduleFreeAdamW, {'lr': 1e0, 'weight_decay': 1e-3}, 5),
]
ADANORM_SUPPORTED_OPTIMIZERS: List[Tuple[Any, Dict[str, Union[float, bool, int]], int]] = [
(AdaBelief, {'lr': 5e-1, 'weight_decay': 1e-3, 'adanorm': True}, 10),
Expand Down
14 changes: 14 additions & 0 deletions tests/test_optimizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -594,3 +594,17 @@ def test_dynamic_scaler():
scaler = DynamicLossScaler(init_scale=2.0**15, scale_window=1, threshold=1e-2)
scaler.decrease_loss_scale()
scaler.update_scale(overflow=False)


def test_schedule_free_train_mode():
    """Smoke-test that schedule-free optimizers support reset/train/eval mode switching."""
    parameter = simple_parameter(True)

    # Both schedule-free variants share the same train/eval mode API.
    for optimizer_name in ('ScheduleFreeAdamW', 'ScheduleFreeSGD'):
        optimizer = load_optimizer(optimizer_name)([parameter])
        optimizer.reset()
        optimizer.train()
        optimizer.eval()

0 comments on commit 723f9a0

Please sign in to comment.