Add StepLR scheduler #109

Open · wants to merge 9 commits into main
18 changes: 11 additions & 7 deletions docs/config.md
@@ -176,13 +176,17 @@ The config file has three main sections:
     - `lr`: (float) Learning rate of type float. *Default*: 1e-3
     - `amsgrad`: (bool) Enable AMSGrad with the optimizer. *Default*: False
 - `lr_scheduler`
-    - `mode`: (str) One of "min", "max". In min mode, lr will be reduced when the quantity monitored has stopped decreasing; in max mode it will be reduced when the quantity monitored has stopped increasing. *Default*: "min".
-    - `threshold`: (float) Threshold for measuring the new optimum, to only focus on significant changes. *Default*: 1e-4.
-    - `threshold_mode`: (str) One of "rel", "abs". In rel mode, dynamic_threshold = best * ( 1 + threshold ) in max mode or best * ( 1 - threshold ) in min mode. In abs mode, dynamic_threshold = best + threshold in max mode or best - threshold in min mode. *Default*: "rel".
-    - `cooldown`: (int) Number of epochs to wait before resuming normal operation after lr has been reduced. *Default*: 0
-    - `patience`: (int) Number of epochs with no improvement after which learning rate will be reduced. For example, if patience = 2, then we will ignore the first 2 epochs with no improvement, and will only decrease the LR after the third epoch if the loss still hasn’t improved then. *Default*: 10.
-    - `factor`: (float) Factor by which the learning rate will be reduced. new_lr = lr * factor. *Default*: 0.1.
-    - `min_lr`: (float or List[float]) A scalar or a list of scalars. A lower bound on the learning rate of all param groups or each group respectively. *Default*: 0.
+    - `scheduler`: (str) Name of the scheduler to use. Valid schedulers: `"StepLR"`, `"ReduceLROnPlateau"`.
+    - `step_lr`:
+        - `step_size`: (int) Period of learning rate decay. If `step_size = 10`, the learning rate is reduced by a factor of `gamma` every 10 epochs.
+        - `gamma`: (float) Multiplicative factor of learning rate decay. *Default*: 0.1.
+    - `reduce_lr_on_plateau`:
+        - `threshold`: (float) Threshold for measuring the new optimum, to only focus on significant changes. *Default*: 1e-4.
+        - `threshold_mode`: (str) One of "rel", "abs". In rel mode, dynamic_threshold = best * ( 1 + threshold ) in max mode or best * ( 1 - threshold ) in min mode. In abs mode, dynamic_threshold = best + threshold in max mode or best - threshold in min mode. *Default*: "rel".
+        - `cooldown`: (int) Number of epochs to wait before resuming normal operation after lr has been reduced. *Default*: 0
+        - `patience`: (int) Number of epochs with no improvement after which learning rate will be reduced. For example, if patience = 2, then we will ignore the first 2 epochs with no improvement, and will only decrease the LR after the third epoch if the loss still hasn’t improved then. *Default*: 10.
+        - `factor`: (float) Factor by which the learning rate will be reduced. new_lr = lr * factor. *Default*: 0.1.
+        - `min_lr`: (float or List[float]) A scalar or a list of scalars. A lower bound on the learning rate of all param groups or each group respectively. *Default*: 0.
 - `early_stopping`
     - `stop_training_on_plateau`: (bool) True if early stopping should be enabled.
     - `min_delta`: (float) Minimum change in the monitored quantity to qualify as an improvement, i.e. an absolute change of less than or equal to min_delta, will count as no improvement.
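In short, `StepLR` decays on a fixed schedule (new_lr = lr * gamma ** (epoch // step_size)), while `ReduceLROnPlateau` waits for the monitored metric to stall. A quick sketch of the `StepLR` arithmetic, using only stock `torch.optim` calls (the parameter and epoch count below are illustrative, not from this PR):

```python
import torch

# Illustrative placeholder parameter/optimizer; only standard torch.optim APIs are used.
param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.Adam([param], lr=1e-3, amsgrad=False)

# StepLR with the documented default gamma=0.1; step_size=10 chosen for illustration.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
for epoch in range(25):
    # ... one epoch of training ...
    scheduler.step()

# lr = 1e-3 * 0.1 ** (25 // 10) = 1e-5 after 25 epochs
print(optimizer.param_groups[0]["lr"])
```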
15 changes: 9 additions & 6 deletions docs/config_bottomup.yaml
@@ -3,6 +3,7 @@ data_config:
   train_labels_path: minimal_instance.pkg.slp
   val_labels_path: minimal_instance.pkg.slp
   user_instances_only: True
+  chunk_size: 100
   preprocessing:
     max_width: null
     max_height: null
@@ -104,12 +105,14 @@ trainer_config:
     lr: 0.0001
     amsgrad: false
   lr_scheduler:
-    threshold: 1.0e-07
-    threshold_mode: abs
-    cooldown: 3
-    patience: 5
-    factor: 0.5
-    min_lr: 1.0e-08
+    scheduler: ReduceLROnPlateau
+    reduce_lr_on_plateau:
+      threshold: 1.0e-07
+      threshold_mode: abs
+      cooldown: 3
+      patience: 5
+      factor: 0.5
+      min_lr: 1.0e-08
   early_stopping:
     stop_training_on_plateau: true
     min_delta: 1.0e-08
15 changes: 9 additions & 6 deletions docs/config_centroid.yaml
@@ -3,6 +3,7 @@ data_config:
   train_labels_path: minimal_instance.pkg.slp
   val_labels_path: minimal_instance.pkg.slp
   user_instances_only: True
+  chunk_size: 100
   preprocessing:
     max_width:
     max_height:
@@ -133,12 +134,14 @@ trainer_config:
     lr: 0.0001
     amsgrad: false
   lr_scheduler:
-    threshold: 1.0e-07
-    threshold_mode: abs
-    cooldown: 3
-    patience: 5
-    factor: 0.5
-    min_lr: 1.0e-08
+    scheduler: ReduceLROnPlateau
+    reduce_lr_on_plateau:
+      threshold: 1.0e-07
+      threshold_mode: abs
+      cooldown: 3
+      patience: 5
+      factor: 0.5
+      min_lr: 1.0e-08
   early_stopping:
     stop_training_on_plateau: True
     min_delta: 1.0e-08
16 changes: 10 additions & 6 deletions docs/config_topdown_centered_instance.yaml
@@ -3,6 +3,7 @@ data_config:
   train_labels_path: minimal_instance.pkg.slp
   val_labels_path: minimal_instance.pkg.slp
   user_instances_only: True
+  chunk_size: 100
   preprocessing:
     max_width:
     max_height:
@@ -11,6 +12,7 @@ data_config:
     crop_hw:
     - 160
     - 160
+    min_crop_size:
   use_augmentations_train: true
   augmentation_config:
     geometric:
@@ -102,12 +104,14 @@ trainer_config:
     lr: 0.0001
     amsgrad: false
   lr_scheduler:
-    threshold: 1.0e-07
-    threshold_mode: abs
-    cooldown: 3
-    patience: 5
-    factor: 0.5
-    min_lr: 1.0e-08
+    scheduler: ReduceLROnPlateau
+    reduce_lr_on_plateau:
+      threshold: 1.0e-07
+      threshold_mode: abs
+      cooldown: 3
+      patience: 5
+      factor: 0.5
+      min_lr: 1.0e-08
   early_stopping:
     stop_training_on_plateau: True
     min_delta: 1.0e-08
104 changes: 104 additions & 0 deletions initial_config.yaml
@@ -0,0 +1,104 @@
+data_config:
+  provider: LabelsReader
+  train_labels_path: C:\Users\TalmoLab\Desktop\Divya\sleap-nn\tests\assets/minimal_instance.pkg.slp
+  val_labels_path: C:\Users\TalmoLab\Desktop\Divya\sleap-nn\tests\assets/minimal_instance.pkg.slp
+  user_instances_only: true
+  chunk_size: 100
+  preprocessing:
+    is_rgb: false
+    max_width: null
+    max_height: null
+    scale: 1.0
+    crop_hw:
+    - 160
+    - 160
+    min_crop_size: null
+  use_augmentations_train: true
+  augmentation_config:
+    intensity:
+      contrast_p: 1.0
+    geometric:
+      rotation: 180.0
+      scale: null
+      translate_width: 0
+      translate_height: 0
+      affine_p: 0.5
+model_config:
+  init_weights: default
+  pre_trained_weights: null
+  backbone_type: unet
+  backbone_config:
+    in_channels: 1
+    kernel_size: 3
+    filters: 16
+    filters_rate: 1.5
+    max_stride: 8
+    convs_per_block: 2
+    stacks: 1
+    stem_stride: null
+    middle_block: true
+    up_interpolate: true
+  head_configs:
+    single_instance: null
+    centroid: null
+    bottomup: null
+    centered_instance:
+      confmaps:
+        part_names:
+        - '0'
+        - '1'
+        anchor_part: 1
+        sigma: 1.5
+        output_stride: 2
+trainer_config:
+  train_data_loader:
+    batch_size: 1
+    shuffle: true
+    num_workers: 2
+  val_data_loader:
+    batch_size: 1
+    num_workers: 0
+  model_ckpt:
+    save_top_k: 1
+    save_last: true
+  early_stopping:
+    stop_training_on_plateau: true
+    min_delta: 1.0e-08
+    patience: 20
+  trainer_devices: 1
+  trainer_accelerator: cpu
+  enable_progress_bar: false
+  steps_per_epoch: null
+  max_epochs: 2
+  seed: 1000
+  use_wandb: false
+  save_ckpt: false
+  save_ckpt_path: null
+  bin_files_path: null
+  resume_ckpt_path: null
+  wandb:
+    entity: null
+    project: test
+    name: test_run
+    wandb_mode: offline
+    api_key: ''
+    prv_runid: null
+    log_params:
+    - trainer_config.optimizer_name
+    - trainer_config.optimizer.amsgrad
+    - trainer_config.optimizer.lr
+    - model_config.backbone_type
+    - model_config.init_weights
+  optimizer_name: Adam
+  optimizer:
+    lr: 0.0001
+    amsgrad: false
+  lr_scheduler:
+    scheduler: ReduceLROnPlateau
+    reduce_lr_on_plateau:
+      threshold: 1.0e-07
+      threshold_mode: rel
+      cooldown: 3
+      patience: 5
+      factor: 0.5
+      min_lr: 1.0e-08
4 changes: 2 additions & 2 deletions sleap_nn/data/streaming_datasets.py
@@ -151,6 +151,8 @@ def __init__(
         self.apply_aug = apply_aug
         self.aug_config = augmentation_config
         self.input_scale = input_scale
+        # Re-crop to original crop size
+        self.crop_hw = [int(x * self.input_scale) for x in self.crop_hw]

     def __getitem__(self, index):
         """Apply augmentation and generate confidence maps."""
@@ -170,8 +172,6 @@ def __getitem__(self, index):
                 ex["instance_image"], ex["instance"], **self.aug_config.geometric
             )

-        # Re-crop to original crop size
-        self.crop_hw = [int(x * self.input_scale) for x in self.crop_hw]
         ex["instance_bbox"] = torch.unsqueeze(
             make_centered_bboxes(ex["centroid"][0], self.crop_hw[0], self.crop_hw[1]), 0
         )
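The relocation above is a bug fix as much as a cleanup: the comprehension reassigns `self.crop_hw`, so running it inside `__getitem__` shrinks the crop again on every fetch. A minimal sketch of that failure mode, using a hypothetical toy class rather than the real dataset:

```python
class ToyDataset:
    """Toy stand-in showing why crop rescaling belongs in __init__."""

    def __init__(self, crop_hw=(160, 160), input_scale=0.5):
        self.crop_hw = list(crop_hw)
        self.input_scale = input_scale

    def __getitem__(self, index):
        # Buggy placement: reassigning shared state on every access.
        self.crop_hw = [int(x * self.input_scale) for x in self.crop_hw]
        return self.crop_hw


ds = ToyDataset()
print(ds[0])  # [80, 80]
print(ds[0])  # [40, 40] <- the crop keeps shrinking on repeated access
```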
40 changes: 30 additions & 10 deletions sleap_nn/training/model_trainer.py
@@ -636,16 +636,36 @@ def configure_optimizers(self):
             lr=self.trainer_config.optimizer.lr,
             amsgrad=self.trainer_config.optimizer.amsgrad,
         )
-        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
-            optimizer,
-            mode="min",
-            threshold=self.trainer_config.lr_scheduler.threshold,
-            threshold_mode=self.trainer_config.lr_scheduler.threshold_mode,
-            cooldown=self.trainer_config.lr_scheduler.cooldown,
-            patience=self.trainer_config.lr_scheduler.patience,
-            factor=self.trainer_config.lr_scheduler.factor,
-            min_lr=self.trainer_config.lr_scheduler.min_lr,
-        )
+
+        if self.trainer_config.lr_scheduler.scheduler == "StepLR":
+            scheduler = torch.optim.lr_scheduler.StepLR(
+                optimizer=optimizer,
+                step_size=self.trainer_config.lr_scheduler.step_lr.step_size,
+                gamma=self.trainer_config.lr_scheduler.step_lr.gamma,
+            )
+
+        elif self.trainer_config.lr_scheduler.scheduler == "ReduceLROnPlateau":
+            scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
+                optimizer,
+                mode="min",
+                threshold=self.trainer_config.lr_scheduler.reduce_lr_on_plateau.threshold,
+                threshold_mode=self.trainer_config.lr_scheduler.reduce_lr_on_plateau.threshold_mode,
+                cooldown=self.trainer_config.lr_scheduler.reduce_lr_on_plateau.cooldown,
+                patience=self.trainer_config.lr_scheduler.reduce_lr_on_plateau.patience,
+                factor=self.trainer_config.lr_scheduler.reduce_lr_on_plateau.factor,
+                min_lr=self.trainer_config.lr_scheduler.reduce_lr_on_plateau.min_lr,
+            )
+
+        elif self.trainer_config.lr_scheduler.scheduler is not None:
+            raise ValueError(
+                f"{self.trainer_config.lr_scheduler.scheduler} is not a valid scheduler. Valid schedulers: `'StepLR'`, `'ReduceLROnPlateau'`"
+            )
+
+        if self.trainer_config.lr_scheduler.scheduler is None:
+            return {
+                "optimizer": optimizer,
+            }
+
         return {
             "optimizer": optimizer,
             "lr_scheduler": {
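For readability outside the diff, the new branching reduces to something like the sketch below; the `cfg` attribute names mirror the `trainer_config.lr_scheduler` block from the configs above, and only stock `torch.optim.lr_scheduler` constructors are used:

```python
import torch


def build_lr_scheduler(optimizer: torch.optim.Optimizer, cfg):
    """Map the `lr_scheduler` config block to a torch scheduler (sketch)."""
    if cfg.scheduler == "StepLR":
        # Fixed-schedule decay: lr *= gamma every step_size epochs.
        return torch.optim.lr_scheduler.StepLR(
            optimizer, step_size=cfg.step_lr.step_size, gamma=cfg.step_lr.gamma
        )
    if cfg.scheduler == "ReduceLROnPlateau":
        # Metric-driven decay: lr *= factor after `patience` stagnant epochs.
        p = cfg.reduce_lr_on_plateau
        return torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            mode="min",
            threshold=p.threshold,
            threshold_mode=p.threshold_mode,
            cooldown=p.cooldown,
            patience=p.patience,
            factor=p.factor,
            min_lr=p.min_lr,
        )
    if cfg.scheduler is not None:
        raise ValueError(
            f"{cfg.scheduler} is not a valid scheduler. "
            "Valid schedulers: 'StepLR', 'ReduceLROnPlateau'"
        )
    return None  # no scheduler configured; train at a fixed learning rate
```

Lightning then receives either a plain `{"optimizer": ...}` dict or one that also carries an `lr_scheduler` entry, matching the two return paths in `configure_optimizers`.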
21 changes: 14 additions & 7 deletions tests/assets/minimal_instance/initial_config.yaml
@@ -2,14 +2,17 @@ data_config:
   provider: LabelsReader
   train_labels_path: minimal_instance.pkg.slp
   val_labels_path: minimal_instance.pkg.slp
+  user_instances_only: True
+  chunk_size: 100
   preprocessing:
     max_width: null
     max_height: null
     scale: 1.0
     is_rgb: false
     crop_hw:
-      - 160
-      - 160
+    - 160
+    - 160
+    min_crop_size:
   use_augmentations_train: true
   augmentation_config:
     geometric:
@@ -63,6 +66,7 @@ trainer_config:
   use_wandb: false
   save_ckpt: true
   save_ckpt_path: min_inst_centered
+  bin_files_path:
   resume_ckpt_path: null
   wandb:
     entity: null
@@ -82,11 +86,14 @@
     lr: 0.0001
     amsgrad: false
   lr_scheduler:
-    threshold: 1.0e-07
-    cooldown: 3
-    patience: 5
-    factor: 0.5
-    min_lr: 1.0e-08
+    scheduler: ReduceLROnPlateau
+    reduce_lr_on_plateau:
+      threshold: 1.0e-07
+      threshold_mode: abs
+      cooldown: 3
+      patience: 5
+      factor: 0.5
+      min_lr: 1.0e-08
   early_stopping:
     stop_training_on_plateau: true
     min_delta: 1.0e-08
17 changes: 12 additions & 5 deletions tests/assets/minimal_instance/training_config.yaml
@@ -2,6 +2,8 @@ data_config:
   provider: LabelsReader
   train_labels_path: minimal_instance.pkg.slp
   val_labels_path: minimal_instance.pkg.slp
+  user_instances_only: true
+  chunk_size: 100
   preprocessing:
     max_width: null
     max_height: null
@@ -10,6 +12,7 @@ data_config:
     crop_hw:
     - 160
     - 160
+    min_crop_size: null
   use_augmentations_train: true
   augmentation_config:
     geometric:
@@ -76,6 +79,7 @@ trainer_config:
   use_wandb: false
   save_ckpt: true
   save_ckpt_path: min_inst_centered
+  bin_files_path: null
   resume_ckpt_path: null
   wandb:
     entity: null
@@ -95,11 +99,14 @@
     lr: 0.0001
     amsgrad: false
   lr_scheduler:
-    threshold: 1.0e-07
-    cooldown: 3
-    patience: 5
-    factor: 0.5
-    min_lr: 1.0e-08
+    scheduler: ReduceLROnPlateau
+    reduce_lr_on_plateau:
+      threshold: 1.0e-07
+      threshold_mode: abs
+      cooldown: 3
+      patience: 5
+      factor: 0.5
+      min_lr: 1.0e-08
   early_stopping:
     stop_training_on_plateau: true
     min_delta: 1.0e-08