Skip to content

Commit

Permalink
feat(learner): use best score in early stopping improvement (#318)
Browse files Browse the repository at this point in the history
  • Loading branch information
sebffischer authored Jan 7, 2025
1 parent 1c2a749 commit 01fe903
Show file tree
Hide file tree
Showing 12 changed files with 20 additions and 10 deletions.
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
* feat: Added `n_layers` parameter to MLP
* BREAKING_CHANGE: Early stopping now not uses `epochs - patience` for the internally tuned
values instead of the trained number of `epochs` as it was before.
Also, the improvement is calcualted as the difference between the current and the best score,
not the current and the previous score.
* feat: Added multimodal melanoma example task
* feat: Added a callback to iteratively unfreeze parameters for finetuning
* fix: torch learners can now be used with `AutoTuner`
Expand Down
13 changes: 8 additions & 5 deletions R/CallbackSetEarlyStopping.R
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,18 @@ CallbackSetEarlyStopping = R6Class("CallbackSetEarlyStopping",
self$patience = assert_int(patience, lower = 1L)
self$min_delta = assert_double(min_delta, lower = 0, len = 1L, any.missing = FALSE)
self$stagnation = 0L
self$best_score = NULL
},
on_valid_end = function() {
if (is.null(self$prev_valid_scores)) {
self$prev_valid_scores = self$ctx$last_scores_valid
if (is.null(self$ctx$last_scores_valid)) {
return(NULL)
}
if (is.null(self$ctx$last_scores_valid)) {
if (is.null(self$best_score)) {
self$best_score = self$ctx$last_scores_valid[[1L]]
return(NULL)
}
multiplier = if (self$ctx$measures_valid[[1L]]$minimize) -1 else 1
improvement = multiplier * (self$ctx$last_scores_valid[[1L]] - self$prev_valid_scores[[1L]])
improvement = multiplier * (self$ctx$last_scores_valid[[1L]] - self$best_score)

if (is.na(improvement)) {
lg$warn("Learner %s in epoch %s: Difference between subsequent validation performances is NA",
Expand All @@ -32,7 +33,9 @@ CallbackSetEarlyStopping = R6Class("CallbackSetEarlyStopping",
} else {
self$stagnation = 0
}
self$prev_valid_scores = self$ctx$last_scores_valid
if (improvement > 0) {
self$best_score = self$ctx$last_scores_valid[[1L]]
}
}
)
)
1 change: 0 additions & 1 deletion man/mlr_callback_set.checkpoint.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 1 addition & 2 deletions man/mlr_callback_set.history.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions man/mlr_callback_set.progress.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/mlr_learners.mlp.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/mlr_learners.tab_resnet.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/mlr_learners.torch_featureless.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/mlr_learners.torchvision.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/mlr_learners_torch.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/mlr_learners_torch_image.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/mlr_learners_torch_model.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 01fe903

Please sign in to comment.