Harden lr scheduler (#932)
zyaoj authored Dec 20, 2024
1 parent 7923e6e commit 49e20be
Showing 2 changed files with 150 additions and 8 deletions.
src/fairseq2/optim/lr_scheduler/factory.py (6 additions, 8 deletions)
@@ -94,20 +94,18 @@ def create_cosine_annealing_lr(
     # Validate config and set final_lr
     if (config.final_lr is not None) and (config.final_lr_scale is not None):
         raise ValueError(
-            f"Invalid configuration: Both `final_lr` ({config.final_lr}) and `final_lr_scale` "
-            f"({config.final_lr_scale}) are set. Please specify only one."
-        )
-
-    if (config.final_lr is None) and (config.final_lr_scale is None):
-        raise ValueError(
-            "Invalid configuration: Either `final_lr` or `final_lr_scale` must be specified."
+            f"Invalid configuration: Both `final_lr` ({config.final_lr}) and `final_lr_scale` ({config.final_lr_scale}) are set. Please specify only one."
         )

     # Compute final_lr based on the configuration
     if config.final_lr_scale is not None:
         final_lr = optimizer.param_groups[0]["lr"] * config.final_lr_scale
+    elif config.final_lr is not None:
+        final_lr = config.final_lr
     else:
-        final_lr = config.final_lr  # type: ignore
+        raise ValueError(
+            "Invalid configuration: Either `final_lr` or `final_lr_scale` must be specified."
+        )

     if final_lr > optimizer.param_groups[0]["lr"]:
         log.warning(
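As a quick illustration (not part of the commit), here is a minimal sketch of how the hardened validation is exercised. It assumes the CosineAnnealingLRConfig fields and the create_cosine_annealing_lr(config, optimizer, max_num_steps) signature used by the new tests below; the tiny Conv2d/SGD optimizer is purely hypothetical.

from torch.nn import Conv2d
from torch.optim import SGD

from fairseq2.optim.lr_scheduler.factory import (
    CosineAnnealingLRConfig,
    create_cosine_annealing_lr,
)

# A hypothetical single-parameter-group optimizer, just for illustration.
opt = SGD(Conv2d(1, 1, 1).parameters(), lr=0.05)

# Exactly one of `final_lr` / `final_lr_scale` may be set.
ok = CosineAnnealingLRConfig(cycle_len=80, final_lr=None, final_lr_scale=0.2)
scheduler = create_cosine_annealing_lr(ok, opt, 1000)

# Setting both (or neither) raises a ValueError with an explicit message.
bad = CosineAnnealingLRConfig(cycle_len=80, final_lr=0.02, final_lr_scale=0.2)
try:
    create_cosine_annealing_lr(bad, opt, 1000)
except ValueError as exc:
    print(exc)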
tests/unit/optim/test_lr_scheduler.py (144 additions, 0 deletions)
@@ -23,6 +23,14 @@
    PolynomialDecayLR,
    TriStageLR,
)
from fairseq2.optim.lr_scheduler.factory import (
    CosineAnnealingLRConfig,
    PolynomialDecayLRConfig,
    TriStageLRConfig,
    create_cosine_annealing_lr,
    create_polynomial_decay_lr,
    create_tri_stage_lr,
)


class LRSchedulerTestNet(Module):
@@ -544,3 +552,139 @@ def test_tristage(self) -> None:

        assert lr1 == pytest.approx(final_lr1)
        assert lr2 == pytest.approx(final_lr2)


class TestLRSchedulerFactory:
    # Common constants
    BASE_LR1: float = 0.05
    BASE_LR2: float = 0.5
    NUM_WARMUP_STEPS: int = 100
    START_LR: float = 0.01
    FINAL_LR: float = 0.02
    NUM_STEPS: int = 200

    # CosineAnnealingLR constants
    CYCLE_LEN: int = 80
    CYCLE_MUL: float = 1.2
    LR_MUL: float = 0.5
    FINAL_LR_SCALE: float = 0.2
    MAX_NUM_STEPS: int = 1000

    # PolynomialDecayLR constants
    POLY_POWER: float = 1.5

    # TriStageLR constants
    TRI_STAGE_RATIO: tuple[float, float, float] = (0.1, 0.4, 0.5)
    TRI_START_LR_SCALE: float = 0.05
    TRI_FINAL_LR_SCALE: float = 0.1

    def setup_method(self) -> None:
        """Set up the test environment with base learning rates and an optimizer."""
        self.net = LRSchedulerTestNet()
        self.opt = SGD(
            params=[  # type: ignore[arg-type]
                {"params": self.net.conv1.parameters()},
                {"params": self.net.conv2.parameters(), "lr": self.BASE_LR2},
            ],
            lr=self.BASE_LR1,
        )

    def test_create_cosine_annealing_lr(self) -> None:
        """Test creation of a CosineAnnealingLR with various configurations."""
        # Test with final_lr
        config = CosineAnnealingLRConfig(
            cycle_len=self.CYCLE_LEN,
            num_warmup_steps=self.NUM_WARMUP_STEPS,
            cycle_mul=self.CYCLE_MUL,
            lr_mul=self.LR_MUL,
            start_lr=self.START_LR,
            final_lr=self.FINAL_LR,
            final_lr_scale=None,
        )
        scheduler = create_cosine_annealing_lr(config, self.opt, self.MAX_NUM_STEPS)

        assert isinstance(scheduler, CosineAnnealingLR)
        assert scheduler.get_last_lr() == [self.START_LR, self.START_LR]

        # Test with final_lr_scale
        config = CosineAnnealingLRConfig(
            cycle_len=self.CYCLE_LEN,
            num_warmup_steps=self.NUM_WARMUP_STEPS,
            final_lr=None,
            final_lr_scale=self.FINAL_LR_SCALE,
        )
        scheduler = create_cosine_annealing_lr(config, self.opt, None)
        assert isinstance(scheduler, CosineAnnealingLR)

    @pytest.mark.parametrize(
        "final_lr, final_lr_scale, match_pattern",
        [
            (0.02, 0.2, "Both `final_lr` .* and `final_lr_scale` .* are set"),
            (None, None, "Either `final_lr` or `final_lr_scale` must be specified"),
        ],
    )
    def test_cosine_annealing_lr_final_lr_errors(
        self, final_lr: float | None, final_lr_scale: float | None, match_pattern: str
    ) -> None:
        """Test error scenarios for final_lr and final_lr_scale in CosineAnnealingLR."""
        config = CosineAnnealingLRConfig(
            final_lr=final_lr, final_lr_scale=final_lr_scale
        )
        with pytest.raises(ValueError, match=match_pattern):
            create_cosine_annealing_lr(config, self.opt, self.MAX_NUM_STEPS)

    def test_cosine_annealing_lr_cycle_len_error(self) -> None:
        """Test error when cycle_len is None and max_num_steps is also None."""
        with pytest.raises(ValueError, match="`cycle_len` must be specified"):
            config = CosineAnnealingLRConfig(cycle_len=None)
            create_cosine_annealing_lr(config, self.opt, None)

    def test_create_polynomial_decay_lr(self) -> None:
        """Test creation of a PolynomialDecayLR with various configurations."""
        config = PolynomialDecayLRConfig(
            num_steps=self.NUM_STEPS,
            num_warmup_steps=self.NUM_WARMUP_STEPS,
            power=self.POLY_POWER,
            start_lr=self.START_LR,
            final_lr=self.FINAL_LR,
        )
        scheduler = create_polynomial_decay_lr(config, self.opt, None)

        assert isinstance(scheduler, PolynomialDecayLR)
        assert scheduler.get_last_lr() == [self.START_LR, self.START_LR]

        # Test with num_steps=None and max_num_steps provided
        config = PolynomialDecayLRConfig(num_steps=None)
        scheduler = create_polynomial_decay_lr(config, self.opt, self.MAX_NUM_STEPS)
        assert isinstance(scheduler, PolynomialDecayLR)

        # Test error when both num_steps and max_num_steps are None
        with pytest.raises(ValueError, match="`max_num_steps` must be specified"):
            config = PolynomialDecayLRConfig(num_steps=None)
            create_polynomial_decay_lr(config, self.opt, None)

    def test_create_tri_stage_lr(self) -> None:
        """Test creation of a TriStageLR with various configurations."""
        config = TriStageLRConfig(
            num_steps=self.NUM_STEPS,
            stage_ratio=self.TRI_STAGE_RATIO,
            start_lr_scale=self.TRI_START_LR_SCALE,
            final_lr_scale=self.TRI_FINAL_LR_SCALE,
        )
        scheduler = create_tri_stage_lr(config, self.opt, None)

        expected_lr1 = self.BASE_LR1 * self.TRI_START_LR_SCALE
        expected_lr2 = self.BASE_LR2 * self.TRI_START_LR_SCALE

        assert isinstance(scheduler, TriStageLR)
        assert scheduler.get_last_lr() == [expected_lr1, expected_lr2]

        # Test with num_steps=None and max_num_steps provided
        config = TriStageLRConfig(num_steps=None)
        scheduler = create_tri_stage_lr(config, self.opt, self.MAX_NUM_STEPS)
        assert isinstance(scheduler, TriStageLR)

        # Test error when both num_steps and max_num_steps are None
        with pytest.raises(ValueError, match="`max_num_steps` must be specified"):
            config = TriStageLRConfig(num_steps=None)
            create_tri_stage_lr(config, self.opt, None)
