feat/callback-lr_schedule (#317)
cxzhang4 authored Jan 16, 2025
1 parent 186c554 commit bb1bcac
Showing 19 changed files with 400 additions and 21 deletions.
1 change: 1 addition & 0 deletions DESCRIPTION
@@ -85,6 +85,7 @@ Collate:
'CallbackSetCheckpoint.R'
'CallbackSetEarlyStopping.R'
'CallbackSetHistory.R'
'CallbackSetLRScheduler.R'
'CallbackSetProgress.R'
'CallbackSetTB.R'
'CallbackSetUnfreeze.R'
1 change: 1 addition & 0 deletions NAMESPACE
@@ -63,6 +63,7 @@ S3method(unmarshal_model,learner_torch_model_marshaled)
export(CallbackSet)
export(CallbackSetCheckpoint)
export(CallbackSetHistory)
export(CallbackSetLRScheduler)
export(CallbackSetProgress)
export(CallbackSetTB)
export(CallbackSetUnfreeze)
1 change: 1 addition & 0 deletions NEWS.md
@@ -11,6 +11,7 @@
* feat: Added multimodal melanoma example task
* feat: Added a callback to iteratively unfreeze parameters for finetuning
* fix: torch learners can now be used with `AutoTuner`
* feat: Added different learning rate schedulers as callbacks

# mlr3torch 0.1.2

188 changes: 188 additions & 0 deletions R/CallbackSetLRScheduler.R
@@ -0,0 +1,188 @@
#' @title Learning Rate Scheduling Callback
#'
#' @name mlr_callback_set.lr_scheduler
#'
#' @description
#' Changes the learning rate based on the schedule specified by a `torch::lr_scheduler`.
#'
#' As of this writing, the following are available: [torch::lr_cosine_annealing()], [torch::lr_lambda()], [torch::lr_multiplicative()], [torch::lr_one_cycle()],
#' [torch::lr_reduce_on_plateau()], [torch::lr_step()], and custom schedulers defined with [torch::lr_scheduler()].
#'
#' @param .scheduler (`lr_scheduler_generator`)\cr
#' The `torch` scheduler generator (e.g. `torch::lr_step`).
#' @param step_on_epoch (`logical(1)`)\cr
#' Whether the scheduler steps after every epoch (otherwise after every batch).
#' @param ... (any)\cr
#' The scheduler-specific arguments.
#'
#' @export
CallbackSetLRScheduler = R6Class("CallbackSetLRScheduler",
inherit = CallbackSet,
lock_objects = FALSE,
public = list(
#' @field scheduler_fn (`lr_scheduler_generator`)\cr
#' The `torch` function that creates a learning rate scheduler
scheduler_fn = NULL,
#' @field scheduler (`LRScheduler`)\cr
#' The learning rate scheduler wrapped by this callback
scheduler = NULL,
#' @description
#' Creates a new instance of this [R6][R6::R6Class] class.
initialize = function(.scheduler, step_on_epoch, ...) {
assert_class(.scheduler, "lr_scheduler_generator")
assert_flag(step_on_epoch)

self$scheduler_fn = .scheduler
private$.scheduler_args = list(...)
if (step_on_epoch) {
self$on_epoch_end = function() self$scheduler$step()
} else {
self$on_batch_end = function() self$scheduler$step()
}
},
#' @description
#' Creates the scheduler using the optimizer from the context
on_begin = function() {
self$scheduler = invoke(self$scheduler_fn, optimizer = self$ctx$optimizer, .args = private$.scheduler_args)
}
),
private = list(
.scheduler_args = NULL
)
)
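# A minimal construction sketch (values are illustrative, not from this commit):
# `step_on_epoch` decides which stage hook the callback defines.
cb = CallbackSetLRScheduler$new(.scheduler = torch::lr_step, step_on_epoch = TRUE, step_size = 1)
is.function(cb$on_epoch_end) # TRUE: the scheduler steps once per epoch
is.null(cb$on_batch_end)     # TRUE: no per-batch stepping for this configuration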

# some of the schedulers accept lists
# so they can treat different parameter groups differently
check_class_or_list = function(x, classname) {
if (is.list(x)) check_list(x, types = classname) else check_class(x, classname)
}
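# For illustration (arbitrary example values): lr_lambda accepts one multiplier
# function per parameter group, so both a single function and a list of functions pass.
check_class_or_list(function(epoch) 0.95^epoch, "function")
check_class_or_list(list(function(epoch) 0.95^epoch, function(epoch) 0.99^epoch), "function")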

#' @include TorchCallback.R
mlr3torch_callbacks$add("lr_cosine_annealing", function() {
TorchCallback$new(
callback_generator = CallbackSetLRScheduler,
param_set = ps(
T_max = p_int(tags = c("train", "required")),
eta_min = p_dbl(default = 0, tags = "train"),
last_epoch = p_int(default = -1, tags = "train"),
verbose = p_lgl(default = FALSE, tags = "train")
),
id = "lr_cosine_annealing",
label = "Cosine Annealing LR Scheduler",
man = "mlr3torch::mlr_callback_set.lr_scheduler",
additional_args = list(.scheduler = torch::lr_cosine_annealing, step_on_epoch = TRUE)
)
})
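# A usage sketch for the dictionary entry above (the T_max value is arbitrary):
# the key retrieves the callback via t_clbk() and `...` sets its parameters.
cb_cosine = t_clbk("lr_cosine_annealing", T_max = 100)
cb_cosine$param_set$values$T_max # 100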

#' @include TorchCallback.R
mlr3torch_callbacks$add("lr_lambda", function() {
TorchCallback$new(
callback_generator = CallbackSetLRScheduler,
param_set = ps(
lr_lambda = p_uty(tags = c("train", "required"), custom_check = function(x) check_class_or_list(x, "function")),
last_epoch = p_int(default = -1, tags = "train"),
verbose = p_lgl(default = FALSE, tags = "train")
),
id = "lr_scheduler",
label = "Multiplication by Function LR Scheduler",
man = "mlr3torch::mlr_callback_set.lr_scheduler",
additional_args = list(.scheduler = torch::lr_lambda, step_on_epoch = TRUE)
)
})

#' @include TorchCallback.R
mlr3torch_callbacks$add("lr_multiplicative", function() {
TorchCallback$new(
callback_generator = CallbackSetLRScheduler,
param_set = ps(
lr_lambda = p_uty(tags = c("train", "required"), custom_check = function(x) check_class_or_list(x, "function")),
last_epoch = p_int(default = -1, tags = "train"),
verbose = p_lgl(default = FALSE, tags = "train")
),
id = "lr_multiplicative",
label = "Multiplication by Factor LR Scheduler",
man = "mlr3torch::mlr_callback_set.lr_scheduler",
additional_args = list(.scheduler = torch::lr_multiplicative, step_on_epoch = TRUE)
)
})

#' @include TorchCallback.R
mlr3torch_callbacks$add("lr_one_cycle", function() {
TorchCallback$new(
callback_generator = CallbackSetLRScheduler,
param_set = ps(
max_lr = p_uty(tags = c("train", "required"), custom_check = function(x) check_class_or_list(x, "numeric")),
total_steps = p_int(default = NULL, special_vals = list(NULL), tags = "train"),
epochs = p_int(default = NULL, special_vals = list(NULL), tags = "train"),
steps_per_epoch = p_int(default = NULL, special_vals = list(NULL), tags = "train"),
pct_start = p_dbl(default = 0.3, tags = "train"),
anneal_strategy = p_fct(default = "cos", levels = c("cos", "linear"), tags = "train"), # this is a string in the torch fn
cycle_momentum = p_lgl(default = TRUE, tags = "train"),
base_momentum = p_uty(default = 0.85, tags = "train", custom_check = function(x) check_class_or_list(x, "numeric")),
max_momentum = p_uty(default = 0.95, tags = "train", custom_check = function(x) check_class_or_list(x, "numeric")),
div_factor = p_dbl(default = 25, tags = "train"),
final_div_factor = p_dbl(default = 1e4, tags = "train"),
verbose = p_lgl(default = FALSE, tags = "train")
),
id = "lr_one_cycle",
label = "1cyle LR Scheduler",
man = "mlr3torch::mlr_callback_set.lr_scheduler",
additional_args = list(.scheduler = torch::lr_one_cycle, step_on_epoch = FALSE)
)
})

#' @include TorchCallback.R
mlr3torch_callbacks$add("lr_reduce_on_plateau", function() {
TorchCallback$new(
callback_generator = CallbackSetLRScheduler,
param_set = ps(
mode = p_fct(default = "min", levels = c("min", "max"), tags = "train"),
factor = p_dbl(default = 0.1, tags = "train"),
patience = p_int(default = 10, tags = "train"),
threshold = p_dbl(default = 1e-04, tags = "train"),
threshold_mode = p_fct(default = "rel", levels = c("rel", "abs"), tags = "train"),
cooldown = p_int(default = 0, tags = "train"),
min_lr = p_uty(default = 0, tags = "train", custom_check = function(x) check_class_or_list(x, "numeric")),
eps = p_dbl(default = 1e-08, tags = "train"),
verbose = p_lgl(default = FALSE, tags = "train")
),
id = "lr_reduce_on_plateau",
label = "Reduce on Plateau LR Scheduler",
man = "mlr3torch::mlr_callback_set.lr_scheduler",
additional_args = list(.scheduler = torch::lr_reduce_on_plateau, step_on_epoch = TRUE)
)
})

#' @include TorchCallback.R
mlr3torch_callbacks$add("lr_step", function() {
TorchCallback$new(
callback_generator = CallbackSetLRScheduler,
param_set = ps(
step_size = p_int(tags = c("train", "required")),
gamma = p_dbl(default = 0.1, tags = "train"),
last_epoch = p_int(default = -1, tags = "train")
),
id = "lr_step",
label = "Step Decay LR Scheduler",
man = "mlr3torch::mlr_callback_set.lr_scheduler",
additional_args = list(.scheduler = torch::lr_step, step_on_epoch = TRUE)
)
})
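# A sketch of attaching one of these schedulers to a learner (learner id and
# parameter values are illustrative, not prescribed by this commit):
learner = lrn("classif.mlp",
  callbacks = t_clbk("lr_step", step_size = 1, gamma = 0.5),
  epochs = 10, batch_size = 32
)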

#' @param x (`lr_scheduler_generator`)\cr
#' The `torch` scheduler generator defined using `torch::lr_scheduler()`.
#' @param step_on_epoch (`logical(1)`)\cr
#' Whether the scheduler steps after every epoch (otherwise after every batch).
as_lr_scheduler = function(x, step_on_epoch) {
assert_class(x, "lr_scheduler_generator")
assert_flag(step_on_epoch)

class_name = class(x)[1L]

TorchCallback$new(
callback_generator = CallbackSetLRScheduler,
param_set = inferps(x),
id = if (class_name == "") "lr_custom" else class_name,
label = "Custom LR Scheduler",
man = "mlr3torch::mlr_callback_set.lr_scheduler",
additional_args = list(.scheduler = x, step_on_epoch = step_on_epoch)
)
}
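# A rough sketch of wrapping a generator that has no predefined dictionary entry
# (torch::lr_step stands in here for a custom generator built with torch::lr_scheduler()):
cb_custom = as_lr_scheduler(torch::lr_step, step_on_epoch = TRUE)
cb_custom$param_set$values$step_size = 2 # parameters are inferred from the generator's formals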
9 changes: 6 additions & 3 deletions R/TorchCallback.R
@@ -192,8 +192,10 @@ TorchCallback = R6Class("TorchCallback",
#' @template param_label
#' @template param_packages
#' @template param_man
#' @param additional_args (`list()` or `NULL`)\cr
#' Additional construction arguments that are fixed at creation time and not exposed in the parameter set.
#' For the learning rate scheduler callbacks, these are the `torch` scheduler generator (`.scheduler`) and `step_on_epoch`.
initialize = function(callback_generator, param_set = NULL, id = NULL,
label = NULL, packages = NULL, man = NULL) {
label = NULL, packages = NULL, man = NULL, additional_args = NULL) {
assert_class(callback_generator, "R6ClassGenerator")

param_set = assert_param_set(param_set %??% inferps(callback_generator))
@@ -206,7 +208,8 @@
param_set = param_set,
packages = union(packages, "mlr3torch"),
label = label,
man = man
man = man,
additional_args = additional_args
)
}
),
@@ -215,7 +218,7 @@
)
)

#' @title Create a Callback Desctiptor
#' @title Create a Callback Descriptor
#'
#' @description
#' Convenience function to create a custom [`TorchCallback`].
10 changes: 7 additions & 3 deletions R/TorchDescriptor.R
@@ -37,7 +37,9 @@ TorchDescriptor = R6Class("TorchDescriptor",
#' @template param_packages
#' @template param_label
#' @template param_man
initialize = function(generator, id = NULL, param_set = NULL, packages = NULL, label = NULL, man = NULL) {
#' @param additional_args (`list()` or `NULL`)\cr
#' Additional construction arguments that are fixed at creation time and not exposed in the parameter set.
#' For the learning rate scheduler callbacks, these are the `torch` scheduler generator (`.scheduler`) and `step_on_epoch`.
initialize = function(generator, id = NULL, param_set = NULL, packages = NULL, label = NULL, man = NULL, additional_args = NULL) {
assert_true(is.function(generator) || inherits(generator, "R6ClassGenerator"))
self$generator = generator
self$param_set = assert_r6(param_set, "ParamSet", null.ok = TRUE) %??% inferps(generator)
@@ -63,6 +65,7 @@
self$id = assert_string(id %??% class(generator)[[1L]], min.chars = 1L)
self$label = assert_string(label %??% self$id, min.chars = 1L)
self$packages = assert_names(unique(union(packages, c("torch", "mlr3torch"))), type = "strict")
private$.additional_args = assert_list(additional_args, null.ok = TRUE)
},
#' @description
#' Prints the object
@@ -86,9 +89,9 @@
# The torch generators could also be constructed with the $new() method, but then the return value
# would be the actual R6 class and not the wrapped function.
if (is.function(self$generator)) {
invoke(self$generator, .args = self$param_set$get_values())
invoke(self$generator, .args = c(self$param_set$get_values(), private$.additional_args))
} else {
invoke(self$generator$new, .args = self$param_set$get_values())
invoke(self$generator$new, .args = c(self$param_set$get_values(), private$.additional_args))
}
},
#' @description
@@ -107,6 +110,7 @@
}
),
private = list(
.additional_args = NULL,
.additional_phash_input = function() {
stopf("Classes inheriting from TorchDescriptor must implement the .additional_phash_input() method.")
},
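# The TorchDescriptor change above merges additional_args with the configured
# parameter values when the generator is instantiated. A rough sketch of that
# merge for the learning rate scheduler callback (an illustration, not the
# actual method body):
values = list(step_size = 1, gamma = 0.5)                       # from param_set$get_values()
extra = list(.scheduler = torch::lr_step, step_on_epoch = TRUE) # from additional_args
callback_set = do.call(CallbackSetLRScheduler$new, c(values, extra))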
6 changes: 5 additions & 1 deletion man/TorchCallback.Rd


6 changes: 5 additions & 1 deletion man/TorchDescriptor.Rd

