Skip to content

Commit

Permalink
note to self
Browse files Browse the repository at this point in the history
  • Loading branch information
cxzhang4 committed Jan 9, 2025
1 parent aa74077 commit 210de81
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 1 deletion.
8 changes: 8 additions & 0 deletions R/CallbackSetLRScheduler.R
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ CallbackSetLRScheduler = R6Class("CallbackSetLRScheduler",
#' @description
#' Creates the scheduler using the optimizer from the context
on_begin = function() {
# TODO: check that the .scheduler_args do not have the cb prefix (pretty sure this is true)
self$scheduler = invoke(self$scheduler_fn, optimizer = self$ctx$optimizer, .args = private$.scheduler_args)
},
#' @description
Expand Down Expand Up @@ -68,6 +69,7 @@ mlr3torch_callbacks$add("lr_scheduler_cosine_annealing", function() {
TorchCallback$new(
callback_generator = CallbackSetLRScheduler,
param_set = ps(
.scheduler = p_uty(tags = c("train", "required")),
T_max = p_int(tags = c("train", "required")),
eta_min = p_dbl(default = 0, lower = 0, tags = "train"),
last_epoch = p_int(default = -1, tags = "train"),
Expand All @@ -86,6 +88,7 @@ mlr3torch_callbacks$add("lr_scheduler_lambda", function() {
TorchCallback$new(
callback_generator = CallbackSetLRScheduler,
param_set = ps(
.scheduler = p_uty(tags = c("train", "required")),
lr_lambda = p_uty(tags = c("train"), custom_check = function(x) check_class_or_list(x, "function")), # TODO: assert fn or list of fns
last_epoch = p_int(default = -1, lower = -1, tags = "train"),
verbose = p_lgl(default = FALSE, tags = "train")
Expand All @@ -102,6 +105,7 @@ mlr3torch_callbacks$add("lr_scheduler_multiplicative", function() {
TorchCallback$new(
callback_generator = CallbackSetLRScheduler,
param_set = ps(
.scheduler = p_uty(tags = c("train", "required")),
lr_lambda = p_uty(tags = c("train"), custom_check = function(x) check_class_or_list(x, "function")),
last_epoch = p_int(default = -1, lower = -1, tags = "train"),
verbose = p_lgl(default = FALSE, tags = "train")
Expand All @@ -113,11 +117,13 @@ mlr3torch_callbacks$add("lr_scheduler_multiplicative", function() {
)
})

# TODO: refactor to operate on batches
#' @include TorchCallback.R
mlr3torch_callbacks$add("lr_scheduler_one_cycle", function() {
TorchCallback$new(
callback_generator = CallbackSetLRScheduler,
param_set = ps(
.scheduler = p_uty(tags = c("train", "required")),
max_lr = p_dbl(tags = "train"),
total_steps = p_int(default = NULL, tags = "train"),
epochs = p_int(default = NULL, tags = "train"),
Expand All @@ -143,6 +149,7 @@ mlr3torch_callbacks$add("lr_scheduler_reduce_on_plateau", function() {
TorchCallback$new(
callback_generator = CallbackSetLRScheduler,
param_set = ps(
.scheduler = p_uty(tags = c("train", "required")),
mode = p_fct(default = "min", levels = c("min", "max"), tags = "train"),
factor = p_dbl(default = 0.1, tags = "train"),
patience = p_int(default = 10, tags = "train"),
Expand All @@ -165,6 +172,7 @@ mlr3torch_callbacks$add("lr_scheduler_step", function() {
TorchCallback$new(
callback_generator = CallbackSetLRScheduler,
param_set = ps(
.scheduler = p_uty(tags = c("train", "required")),
step_size = p_int(default = 1, lower = 1, tags = "train"),
gamma = p_dbl(default = 0.1, lower = 0, upper = 1, tags = "train"),
last_epoch = p_int(default = -1, tags = "train")
Expand Down
2 changes: 1 addition & 1 deletion tests/testthat/test_CallbackSetLRScheduler.R
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
test_that("autotest", {
cb = t_clbk("lr_scheduler_cosine_annealing", T_max = 1)
# TODO: figure out how to set .scheduler and T_max (and similar)
# expect_torch_callback(cb)
expect_torch_callback(cb)
})

test_that("decay works", {
Expand Down

0 comments on commit 210de81

Please sign in to comment.