diff --git a/tests/test_optimizers.py b/tests/test_optimizers.py index b7e7c99d0..0546f8fd1 100644 --- a/tests/test_optimizers.py +++ b/tests/test_optimizers.py @@ -464,6 +464,11 @@ def test_prodigy_reset(): assert str(optimizer) == 'Prodigy' +def test_adalite_reset(): + optimizer = load_optimizer('adalite')([simple_zero_rank_parameter(True)]) + optimizer.reset() + + @pytest.mark.parametrize('pre_conditioner_type', [0, 1, 2]) def test_scalable_shampoo_pre_conditioner_with_svd(pre_conditioner_type, environment): (x_data, y_data), _, loss_fn = environment