From bb1bcacc11adb3b946bf33d1c31519ffc1bc94e1 Mon Sep 17 00:00:00 2001 From: Carson Zhang Date: Thu, 16 Jan 2025 19:40:59 +0100 Subject: [PATCH] feat/callback-lr_schedule (#317) --- DESCRIPTION | 1 + NAMESPACE | 1 + NEWS.md | 1 + R/CallbackSetLRScheduler.R | 188 +++++++++++++++++++ R/TorchCallback.R | 9 +- R/TorchDescriptor.R | 10 +- man/TorchCallback.Rd | 6 +- man/TorchDescriptor.Rd | 6 +- man/mlr_callback_set.lr_scheduler.Rd | 92 +++++++++ man/mlr_learners.mlp.Rd | 1 - man/mlr_learners.tab_resnet.Rd | 1 - man/mlr_learners.torch_featureless.Rd | 3 +- man/mlr_learners.torchvision.Rd | 1 - man/mlr_learners_torch.Rd | 1 - man/mlr_learners_torch_image.Rd | 1 - man/mlr_learners_torch_model.Rd | 1 - man/torch_callback.Rd | 2 +- tests/testthat/helper_autotest.R | 12 +- tests/testthat/test_CallbackSetLRScheduler.R | 84 +++++++++ 19 files changed, 400 insertions(+), 21 deletions(-) create mode 100644 R/CallbackSetLRScheduler.R create mode 100644 man/mlr_callback_set.lr_scheduler.Rd create mode 100644 tests/testthat/test_CallbackSetLRScheduler.R diff --git a/DESCRIPTION b/DESCRIPTION index 1a1eccaa..4ecd6bae 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -85,6 +85,7 @@ Collate: 'CallbackSetCheckpoint.R' 'CallbackSetEarlyStopping.R' 'CallbackSetHistory.R' + 'CallbackSetLRScheduler.R' 'CallbackSetProgress.R' 'CallbackSetTB.R' 'CallbackSetUnfreeze.R' diff --git a/NAMESPACE b/NAMESPACE index f68d6a30..492b5139 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -63,6 +63,7 @@ S3method(unmarshal_model,learner_torch_model_marshaled) export(CallbackSet) export(CallbackSetCheckpoint) export(CallbackSetHistory) +export(CallbackSetLRScheduler) export(CallbackSetProgress) export(CallbackSetTB) export(CallbackSetUnfreeze) diff --git a/NEWS.md b/NEWS.md index f0dba4cd..2b77f81b 100644 --- a/NEWS.md +++ b/NEWS.md @@ -11,6 +11,7 @@ * feat: Added multimodal melanoma example task * feat: Added a callback to iteratively unfreeze parameters for finetuning * fix: torch learners can now be used with `AutoTuner` +* feat: Added different learning rate schedulers as callbacks # mlr3torch 0.1.2 diff --git a/R/CallbackSetLRScheduler.R b/R/CallbackSetLRScheduler.R new file mode 100644 index 00000000..8269ff6f --- /dev/null +++ b/R/CallbackSetLRScheduler.R @@ -0,0 +1,188 @@ +#' @title Learning Rate Scheduling Callback +#' +#' @name mlr_callback_set.lr_scheduler +#' +#' @description +#' Changes the learning rate based on the schedule specified by a `torch::lr_scheduler`. +#' +#' As of this writing, the following are available: [torch::lr_cosine_annealing()], [torch::lr_lambda()], [torch::lr_multiplicative()], [torch::lr_one_cycle()], +#' [torch::lr_reduce_on_plateau()], [torch::lr_step()], and custom schedulers defined with [torch::lr_scheduler()]. +#' +#' @param .scheduler (`lr_scheduler_generator`)\cr +#' The `torch` scheduler generator (e.g. `torch::lr_step`). +#' @param ... (any)\cr +#' The scheduler-specific arguments +#' +#' @export +CallbackSetLRScheduler = R6Class("CallbackSetLRScheduler", + inherit = CallbackSet, + lock_objects = FALSE, + public = list( + #' @field scheduler_fn (`lr_scheduler_generator`)\cr + #' The `torch` function that creates a learning rate scheduler + scheduler_fn = NULL, + #' @field scheduler (`LRScheduler`)\cr + #' The learning rate scheduler wrapped by this callback + scheduler = NULL, + #' @description + #' Creates a new instance of this [R6][R6::R6Class] class. + initialize = function(.scheduler, step_on_epoch, ...) 
{
+    assert_class(.scheduler, "lr_scheduler_generator")
+    assert_flag(step_on_epoch)
+
+    self$scheduler_fn = .scheduler
+    private$.scheduler_args = list(...)
+    if (step_on_epoch) {
+      self$on_epoch_end = function() self$scheduler$step()
+    } else {
+      self$on_batch_end = function() self$scheduler$step()
+    }
+    },
+    #' @description
+    #' Creates the scheduler using the optimizer from the context.
+    on_begin = function() {
+      self$scheduler = invoke(self$scheduler_fn, optimizer = self$ctx$optimizer, .args = private$.scheduler_args)
+    }
+  ),
+  private = list(
+    .scheduler_args = NULL
+  )
+)
+
+# some of the schedulers accept lists
+# so they can treat different parameter groups differently
+check_class_or_list = function(x, classname) {
+  if (is.list(x)) check_list(x, types = classname) else check_class(x, classname)
+}
+
+#' @include TorchCallback.R
+mlr3torch_callbacks$add("lr_cosine_annealing", function() {
+  TorchCallback$new(
+    callback_generator = CallbackSetLRScheduler,
+    param_set = ps(
+      T_max = p_int(tags = c("train", "required")),
+      eta_min = p_dbl(default = 0, tags = "train"),
+      last_epoch = p_int(default = -1, tags = "train"),
+      verbose = p_lgl(default = FALSE, tags = "train")
+    ),
+    id = "lr_cosine_annealing",
+    label = "Cosine Annealing LR Scheduler",
+    man = "mlr3torch::mlr_callback_set.lr_scheduler",
+    additional_args = list(.scheduler = torch::lr_cosine_annealing, step_on_epoch = TRUE)
+  )
+})
+
+#' @include TorchCallback.R
+mlr3torch_callbacks$add("lr_lambda", function() {
+  TorchCallback$new(
+    callback_generator = CallbackSetLRScheduler,
+    param_set = ps(
+      lr_lambda = p_uty(tags = c("train", "required"), custom_check = function(x) check_class_or_list(x, "function")),
+      last_epoch = p_int(default = -1, tags = "train"),
+      verbose = p_lgl(default = FALSE, tags = "train")
+    ),
+    id = "lr_lambda",
+    label = "Multiplication by Function LR Scheduler",
+    man = "mlr3torch::mlr_callback_set.lr_scheduler",
+    additional_args = list(.scheduler = torch::lr_lambda, step_on_epoch = TRUE)
+  )
+})
+
+#' @include TorchCallback.R
+mlr3torch_callbacks$add("lr_multiplicative", function() {
+  TorchCallback$new(
+    callback_generator = CallbackSetLRScheduler,
+    param_set = ps(
+      lr_lambda = p_uty(tags = c("train", "required"), custom_check = function(x) check_class_or_list(x, "function")),
+      last_epoch = p_int(default = -1, tags = "train"),
+      verbose = p_lgl(default = FALSE, tags = "train")
+    ),
+    id = "lr_multiplicative",
+    label = "Multiplication by Factor LR Scheduler",
+    man = "mlr3torch::mlr_callback_set.lr_scheduler",
+    additional_args = list(.scheduler = torch::lr_multiplicative, step_on_epoch = TRUE)
+  )
+})
+
+#' @include TorchCallback.R
+mlr3torch_callbacks$add("lr_one_cycle", function() {
+  TorchCallback$new(
+    callback_generator = CallbackSetLRScheduler,
+    param_set = ps(
+      max_lr = p_uty(tags = c("train", "required"), custom_check = function(x) check_class_or_list(x, "numeric")),
+      total_steps = p_int(default = NULL, special_vals = list(NULL), tags = "train"),
+      epochs = p_int(default = NULL, special_vals = list(NULL), tags = "train"),
+      steps_per_epoch = p_int(default = NULL, special_vals = list(NULL), tags = "train"),
+      pct_start = p_dbl(default = 0.3, tags = "train"),
+      anneal_strategy = p_fct(default = "cos", levels = c("cos", "linear"), tags = "train"), # this is a string in the torch fn
+      cycle_momentum = p_lgl(default = TRUE, tags = "train"),
+      base_momentum = p_uty(default = 0.85, tags = "train", custom_check = function(x) check_class_or_list(x, "numeric")),
+      max_momentum = p_uty(default = 0.95, tags = "train", custom_check = function(x) check_class_or_list(x, "numeric")),
+      div_factor = p_dbl(default = 25, tags = "train"),
+      final_div_factor = p_dbl(default = 1e4, tags = "train"),
+      verbose = p_lgl(default = FALSE, tags = "train")
+    ),
+    id = "lr_one_cycle",
+    label = "1cycle LR Scheduler",
+    man = "mlr3torch::mlr_callback_set.lr_scheduler",
+    additional_args = list(.scheduler = torch::lr_one_cycle, step_on_epoch = FALSE)
+  )
+})
+
+#' @include TorchCallback.R
+mlr3torch_callbacks$add("lr_reduce_on_plateau", function() {
+  TorchCallback$new(
+    callback_generator = CallbackSetLRScheduler,
+    param_set = ps(
+      mode = p_fct(default = "min", levels = c("min", "max"), tags = "train"),
+      factor = p_dbl(default = 0.1, tags = "train"),
+      patience = p_int(default = 10, tags = "train"),
+      threshold = p_dbl(default = 1e-04, tags = "train"),
+      threshold_mode = p_fct(default = "rel", levels = c("rel", "abs"), tags = "train"),
+      cooldown = p_int(default = 0, tags = "train"),
+      min_lr = p_uty(default = 0, tags = "train", custom_check = function(x) check_class_or_list(x, "numeric")),
+      eps = p_dbl(default = 1e-08, tags = "train"),
+      verbose = p_lgl(default = FALSE, tags = "train")
+    ),
+    id = "lr_reduce_on_plateau",
+    label = "Reduce on Plateau LR Scheduler",
+    man = "mlr3torch::mlr_callback_set.lr_scheduler",
+    additional_args = list(.scheduler = torch::lr_reduce_on_plateau, step_on_epoch = TRUE)
+  )
+})
+
+#' @include TorchCallback.R
+mlr3torch_callbacks$add("lr_step", function() {
+  TorchCallback$new(
+    callback_generator = CallbackSetLRScheduler,
+    param_set = ps(
+      step_size = p_int(tags = c("train", "required")),
+      gamma = p_dbl(default = 0.1, tags = "train"),
+      last_epoch = p_int(default = -1, tags = "train")
+    ),
+    id = "lr_step",
+    label = "Step Decay LR Scheduler",
+    man = "mlr3torch::mlr_callback_set.lr_scheduler",
+    additional_args = list(.scheduler = torch::lr_step, step_on_epoch = TRUE)
+  )
+})
+
+#' @param x (`lr_scheduler_generator`)\cr
+#'   The `torch` scheduler generator defined using `torch::lr_scheduler()`.
+#' @param step_on_epoch (`logical(1)`)\cr
+#'   Whether the scheduler steps after every epoch (otherwise after every batch).
+as_lr_scheduler = function(x, step_on_epoch) {
+  assert_class(x, "lr_scheduler_generator")
+  assert_flag(step_on_epoch)
+
+  class_name = class(x)[1L]
+
+  TorchCallback$new(
+    callback_generator = CallbackSetLRScheduler,
+    param_set = inferps(x),
+    id = if (class_name == "") "lr_custom" else class_name,
+    label = "Custom LR Scheduler",
+    man = "mlr3torch::mlr_callback_set.lr_scheduler",
+    additional_args = list(.scheduler = x, step_on_epoch = step_on_epoch)
+  )
+}
diff --git a/R/TorchCallback.R b/R/TorchCallback.R
index ce1cf1b8..f7d396d7 100644
--- a/R/TorchCallback.R
+++ b/R/TorchCallback.R
@@ -192,8 +192,10 @@ TorchCallback = R6Class("TorchCallback",
     #' @template param_label
     #' @template param_packages
     #' @template param_man
+    #' @param additional_args (`list()`)\cr
+    #'   Additional construction arguments passed to the callback, i.e. arguments that are not exposed as hyperparameters. The learning rate scheduler callback uses this to pass the `torch` scheduler generator and `step_on_epoch`.
     initialize = function(callback_generator, param_set = NULL, id = NULL,
-                        label = NULL, packages = NULL, man = NULL) {
+                        label = NULL, packages = NULL, man = NULL, additional_args = NULL) {
       assert_class(callback_generator, "R6ClassGenerator")
 
       param_set = assert_param_set(param_set %??% inferps(callback_generator))
@@ -206,7 +208,8 @@ TorchCallback = R6Class("TorchCallback",
         param_set = param_set,
         packages = union(packages, "mlr3torch"),
         label = label,
-        man = man
+        man = man,
+        additional_args = additional_args
       )
     }
   ),
@@ -215,7 +218,7 @@
   )
 )
 
-#' @title Create a Callback Desctiptor
+#' @title Create a Callback Descriptor
 #'
 #' @description
 #' Convenience function to create a custom [`TorchCallback`].
diff --git a/R/TorchDescriptor.R b/R/TorchDescriptor.R
index 1696db95..e8170430 100644
--- a/R/TorchDescriptor.R
+++ b/R/TorchDescriptor.R
@@ -37,7 +37,9 @@ TorchDescriptor = R6Class("TorchDescriptor",
     #' @template param_packages
     #' @template param_label
     #' @template param_man
-    initialize = function(generator, id = NULL, param_set = NULL, packages = NULL, label = NULL, man = NULL) {
+    #' @param additional_args (`list()`)\cr
+    #'   Additional construction arguments passed to the generator, i.e. arguments that are not exposed as hyperparameters. The learning rate scheduler callback uses this to pass the `torch` scheduler generator and `step_on_epoch`.
+    initialize = function(generator, id = NULL, param_set = NULL, packages = NULL, label = NULL, man = NULL, additional_args = NULL) {
       assert_true(is.function(generator) || inherits(generator, "R6ClassGenerator"))
       self$generator = generator
       self$param_set = assert_r6(param_set, "ParamSet", null.ok = TRUE) %??% inferps(generator)
@@ -63,6 +65,7 @@ TorchDescriptor = R6Class("TorchDescriptor",
       self$id = assert_string(id %??% class(generator)[[1L]], min.chars = 1L)
       self$label = assert_string(label %??% self$id, min.chars = 1L)
       self$packages = assert_names(unique(union(packages, c("torch", "mlr3torch"))), type = "strict")
+      private$.additional_args = assert_list(additional_args, null.ok = TRUE)
     },
     #' @description
     #' Prints the object
@@ -86,9 +89,9 @@ TorchDescriptor = R6Class("TorchDescriptor",
       # The torch generators could also be constructed with the $new() method, but then the return value
       # would be the actual R6 class and not the wrapped function.
       if (is.function(self$generator)) {
-        invoke(self$generator, .args = self$param_set$get_values())
+        invoke(self$generator, .args = c(self$param_set$get_values(), private$.additional_args))
       } else {
-        invoke(self$generator$new, .args = self$param_set$get_values())
+        invoke(self$generator$new, .args = c(self$param_set$get_values(), private$.additional_args))
       }
     },
     #' @description
@@ -107,6 +110,7 @@ TorchDescriptor = R6Class("TorchDescriptor",
     }
   ),
   private = list(
+    .additional_args = NULL,
     .additional_phash_input = function() {
       stopf("Classes inheriting from TorchDescriptor must implement the .additional_phash_input() method.")
     },
diff --git a/man/TorchCallback.Rd b/man/TorchCallback.Rd
index 3af77676..80d8309c 100644
--- a/man/TorchCallback.Rd
+++ b/man/TorchCallback.Rd
@@ -113,7 +113,8 @@ Creates a new instance of this \link[R6:R6Class]{R6} class.
   id = NULL,
   label = NULL,
   packages = NULL,
-  man = NULL
+  man = NULL,
+  additional_args = NULL
 )}\if{html}{\out{</div>}}
 }
 
@@ -138,6 +139,9 @@ The R packages this object depends on.}
 \item{\code{man}}{(\code{character(1)})\cr
 String in the format \verb{[pkg]::[topic]} pointing to a manual page for this object.
 The referenced help page can be opened via method \verb{$help()}.}
+
+\item{\code{additional_args}}{(\code{list()})\cr
+Additional construction arguments passed to the callback, i.e. arguments that are not exposed as hyperparameters. The learning rate scheduler callback uses this to pass the \code{torch} scheduler generator and \code{step_on_epoch}.}
 }
 \if{html}{\out{</div>}}
 }
diff --git a/man/TorchDescriptor.Rd b/man/TorchDescriptor.Rd
index 26ad886a..b4871137 100644
--- a/man/TorchDescriptor.Rd
+++ b/man/TorchDescriptor.Rd
@@ -88,7 +88,8 @@ Creates a new instance of this \link[R6:R6Class]{R6} class.
   param_set = NULL,
   packages = NULL,
   label = NULL,
-  man = NULL
+  man = NULL,
+  additional_args = NULL
 )}\if{html}{\out{</div>}}
 }
 
@@ -112,6 +113,9 @@ Label for the new instance.}
 \item{\code{man}}{(\code{character(1)})\cr
 String in the format \verb{[pkg]::[topic]} pointing to a manual page for this object.
 The referenced help page can be opened via method \verb{$help()}.}
+
+\item{\code{additional_args}}{(\code{list()})\cr
+Additional construction arguments passed to the generator, i.e. arguments that are not exposed as hyperparameters. The learning rate scheduler callback uses this to pass the \code{torch} scheduler generator and \code{step_on_epoch}.}
 }
 \if{html}{\out{</div>}}
 }
diff --git a/man/mlr_callback_set.lr_scheduler.Rd b/man/mlr_callback_set.lr_scheduler.Rd
new file mode 100644
index 00000000..9953f5b2
--- /dev/null
+++ b/man/mlr_callback_set.lr_scheduler.Rd
@@ -0,0 +1,92 @@
+% Generated by roxygen2: do not edit by hand
+% Please edit documentation in R/CallbackSetLRScheduler.R
+\name{mlr_callback_set.lr_scheduler}
+\alias{mlr_callback_set.lr_scheduler}
+\alias{CallbackSetLRScheduler}
+\title{Learning Rate Scheduling Callback}
+\description{
+Changes the learning rate based on the schedule specified by a \code{torch::lr_scheduler}.
+
+As of this writing, the following are available: \code{\link[torch:lr_cosine_annealing]{torch::lr_cosine_annealing()}}, \code{\link[torch:lr_lambda]{torch::lr_lambda()}}, \code{\link[torch:lr_multiplicative]{torch::lr_multiplicative()}}, \code{\link[torch:lr_one_cycle]{torch::lr_one_cycle()}},
+\code{\link[torch:lr_reduce_on_plateau]{torch::lr_reduce_on_plateau()}}, \code{\link[torch:lr_step]{torch::lr_step()}}, and custom schedulers defined with \code{\link[torch:lr_scheduler]{torch::lr_scheduler()}}.
+}
+\section{Super class}{
+\code{\link[mlr3torch:CallbackSet]{mlr3torch::CallbackSet}} -> \code{CallbackSetLRScheduler}
+}
+\section{Public fields}{
+\if{html}{\out{<div class="r6-fields">}}
+\describe{
+\item{\code{scheduler_fn}}{(\code{lr_scheduler_generator})\cr
+The \code{torch} function that creates a learning rate scheduler}
+
+\item{\code{scheduler}}{(\code{LRScheduler})\cr
+The learning rate scheduler wrapped by this callback}
+}
+\if{html}{\out{</div>}}
+}
+\section{Methods}{
+\subsection{Public methods}{
+\itemize{
+\item \href{#method-CallbackSetLRScheduler-new}{\code{CallbackSetLRScheduler$new()}}
+\item \href{#method-CallbackSetLRScheduler-on_begin}{\code{CallbackSetLRScheduler$on_begin()}}
+\item \href{#method-CallbackSetLRScheduler-clone}{\code{CallbackSetLRScheduler$clone()}}
+}
+}
+\if{html}{\out{
+<details open><summary>Inherited methods</summary>
+</details>
+}}
+\if{html}{\out{<hr>}}
+\if{html}{\out{<a id="method-CallbackSetLRScheduler-new"></a>}}
+\if{latex}{\out{\hypertarget{method-CallbackSetLRScheduler-new}{}}}
+\subsection{Method \code{new()}}{
+Creates a new instance of this \link[R6:R6Class]{R6} class.
+\subsection{Usage}{
+\if{html}{\out{<div class="r">}}\preformatted{CallbackSetLRScheduler$new(.scheduler, step_on_epoch, ...)}\if{html}{\out{</div>}}
+}
+
+\subsection{Arguments}{
+\if{html}{\out{<div class="arguments">}}
+\describe{
+\item{\code{.scheduler}}{(\code{lr_scheduler_generator})\cr
+The \code{torch} scheduler generator (e.g. \code{torch::lr_step}).}
+
+\item{\code{...}}{(any)\cr
+The scheduler-specific arguments.}
+}
+\if{html}{\out{</div>}}
+}
+}
+\if{html}{\out{<hr>}}
+\if{html}{\out{<a id="method-CallbackSetLRScheduler-on_begin"></a>}}
+\if{latex}{\out{\hypertarget{method-CallbackSetLRScheduler-on_begin}{}}}
+\subsection{Method \code{on_begin()}}{
+Creates the scheduler using the optimizer from the context.
+\subsection{Usage}{
+\if{html}{\out{<div class="r">}}\preformatted{CallbackSetLRScheduler$on_begin()}\if{html}{\out{</div>}}
+}
+
+}
+\if{html}{\out{<hr>}}
+\if{html}{\out{<a id="method-CallbackSetLRScheduler-clone"></a>}}
+\if{latex}{\out{\hypertarget{method-CallbackSetLRScheduler-clone}{}}}
+\subsection{Method \code{clone()}}{
+The objects of this class are cloneable with this method.
+\subsection{Usage}{
+\if{html}{\out{<div class="r">}}\preformatted{CallbackSetLRScheduler$clone(deep = FALSE)}\if{html}{\out{</div>}}
+}
+
+\subsection{Arguments}{
+\if{html}{\out{<div class="arguments">}}
+\describe{
+\item{\code{deep}}{Whether to make a deep clone.}
+}
+\if{html}{\out{</div>}}
+}
+}
+}
diff --git a/man/mlr_learners.mlp.Rd b/man/mlr_learners.mlp.Rd
index 276572f1..7f03f5c7 100644
--- a/man/mlr_learners.mlp.Rd
+++ b/man/mlr_learners.mlp.Rd
@@ -107,7 +107,6 @@ Other Learner:
Inherited methods
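
A minimal usage sketch for the new callbacks: they are registered in the
callback dictionary, so `t_clbk()` retrieves them and sets their
hyperparameters. This sketch assumes the `classif.mlp` learner from mlr3torch
and the `sonar` example task from mlr3:

    library(mlr3)
    library(mlr3torch)

    # anneal the learning rate over 10 epochs, never dropping below 1e-5
    cb = t_clbk("lr_cosine_annealing", T_max = 10, eta_min = 1e-5)

    learner = lrn("classif.mlp",
      epochs = 10, batch_size = 32, neurons = 20,
      callbacks = cb
    )
    learner$train(tsk("sonar"))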
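Of the schedulers above, only `lr_one_cycle` is registered with
`step_on_epoch = FALSE`, so its schedule advances after every batch rather
than after every epoch; `total_steps` must therefore count batches, not
epochs. A sketch, assuming 10 epochs on the 208-row `sonar` task at batch
size 32:

    # 10 epochs * ceiling(208 / 32) = 70 batch-wise steps in total
    cb = t_clbk("lr_one_cycle", max_lr = 0.1, total_steps = 10 * ceiling(208 / 32))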
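Custom schedulers built with `torch::lr_scheduler()` are wrapped through
`as_lr_scheduler()`, which infers the parameter set from the generator and
derives the callback id from the generator's class name. A sketch with a toy,
made-up decay rule; since the patch does not export `as_lr_scheduler()`, it
is reached here via `:::`:

    library(mlr3torch)

    # toy generator: halve every parameter group's learning rate at each step
    lr_halving = torch::lr_scheduler(
      "lr_halving",
      initialize = function(optimizer, last_epoch = -1) {
        super$initialize(optimizer, last_epoch)
      },
      get_lr = function() {
        sapply(self$optimizer$param_groups, function(g) g$lr / 2)
      }
    )

    # steps once per epoch; the callback id becomes "lr_halving"
    cb = mlr3torch:::as_lr_scheduler(lr_halving, step_on_epoch = TRUE)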