Skip to content

Commit

Permalink
Feat/inner valid (#294)
Browse files Browse the repository at this point in the history
* anticipate mlr3 and mlr3pipelines changes

* update workflows

* typo in workflow

* renaming

* fix workflow

* ...

* wip

* ...

* ...

* fix: add patterns to global variables

---------

Co-authored-by: be-marc <[email protected]>
  • Loading branch information
sebffischer and be-marc authored Jun 20, 2024
1 parent e9905f4 commit c566afb
Show file tree
Hide file tree
Showing 28 changed files with 406 additions and 137 deletions.
43 changes: 43 additions & 0 deletions .github/workflows/test-task-1.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# r cmd check workflow of the mlr3 ecosystem v0.1.0
# https://github.com/mlr-org/actions
on:
workflow_dispatch:
push:
branches:
- main
pull_request:
branches:
- main

name: mlr3 & mlr3pipelines change

jobs:
r-cmd-check:
runs-on: ${{ matrix.config.os }}

name: ${{ matrix.config.os }} (${{ matrix.config.r }})

env:
GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }}

strategy:
fail-fast: false
matrix:
config:
- {os: ubuntu-latest, r: 'release'}

steps:
- uses: actions/checkout@v3

- name: mlr3
run: 'echo -e "Remotes:\n mlr-org/mlr3@feat/train-predict,\n mlr-org/mlr3pipelines$fixt/uses_test_task" >> DESCRIPTION'

- uses: r-lib/actions/setup-r@v2
with:
r-version: ${{ matrix.config.r }}

- uses: r-lib/actions/setup-r-dependencies@v2
with:
extra-packages: any::rcmdcheck
needs: check
- uses: r-lib/actions/check-r-package@v2
43 changes: 43 additions & 0 deletions .github/workflows/test-task-2.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# r cmd check workflow of the mlr3 ecosystem v0.1.0
# https://github.com/mlr-org/actions
on:
workflow_dispatch:
push:
branches:
- main
pull_request:
branches:
- main

name: mlr3 & mlr3pipelines change

jobs:
r-cmd-check:
runs-on: ${{ matrix.config.os }}

name: ${{ matrix.config.os }} (${{ matrix.config.r }})

env:
GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }}

strategy:
fail-fast: false
matrix:
config:
- {os: ubuntu-latest, r: 'release'}

steps:
- uses: actions/checkout@v3

- name: mlr3
run: 'echo -e "Remotes:\n mlr-org/mlr3@feat/train-predict,\n mlr-org/mlr3pipelines$feat/test-rows" >> DESCRIPTION'

- uses: r-lib/actions/setup-r@v2
with:
r-version: ${{ matrix.config.r }}

- uses: r-lib/actions/setup-r-dependencies@v2
with:
extra-packages: any::rcmdcheck
needs: check
- uses: r-lib/actions/check-r-package@v2
4 changes: 3 additions & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ Imports:
checkmate,
data.table,
mlr3misc (>= 0.9.4),
paradox,
paradox (>= 1.0.0),
R6
Suggests:
DiceKriging,
Expand All @@ -45,6 +45,8 @@ Suggests:
rmarkdown,
testthat (>= 3.0.0),
xgboost (>= 1.6.0)
Remotes:
mlr-org/mlr3@feat/inner_valid
Config/testthat/edition: 3
Encoding: UTF-8
NeedsCompilation: no
Expand Down
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ import(paradox)
importFrom(R6,R6Class)
importFrom(mlr3,LearnerClassif)
importFrom(mlr3,LearnerRegr)
importFrom(mlr3,assert_validate)
importFrom(mlr3,mlr_learners)
importFrom(stats,predict)
importFrom(stats,reformulate)
Expand Down
3 changes: 3 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
# mlr3learners (development version)

* feat: `LearnerClassifXgboost` and `LearnerRegrXgboost` now support internal tuning and validation.
This now also works in conjunction with `mlr3pipelines`.

# mlr3learners 0.6.0

* Adaption to new paradox version 1.0.0.
Expand Down
98 changes: 73 additions & 25 deletions R/LearnerClassifXgboost.R
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,9 @@
#'
#' @section Early stopping:
#' Early stopping can be used to find the optimal number of boosting rounds.
#' The `early_stopping_set` parameter controls which set is used to monitor the performance.
#' Set `early_stopping_set = "test"` to monitor the performance of the model on the test set while training.
#' The test set for early stopping can be set with the `"test"` row role in the [mlr3::Task].
#' Additionally, the range must be set in which the performance must increase with `early_stopping_rounds` and the maximum number of boosting rounds with `nrounds`.
#' While resampling, the test set is automatically applied from the [mlr3::Resampling].
#' Not that using the test set for early stopping can potentially bias the performance scores.
#' See the section on early stopping in the examples.
#' The `early_stopping` parameter controls which set is used to monitor the performance.
#' Set `early_stopping_rounds` to an integer vaulue to monitor the performance of the model on the validation set while training.
#' For information on how to configure the valdiation set, see the *Validation* section of [`mlr3::Learner`].
#'
#' @templateVar id classif.xgboost
#' @template learner
Expand All @@ -55,19 +51,24 @@
#' # Train learner with early stopping on spam data set
#' task = tsk("spam")
#'
#' # Split task into training and test set
#' # Split task into training and validation data
#' split = partition(task, ratio = 0.8)
#' task$set_row_roles(split$test, "test")
#' task$divide(split$test)
#'
#' task
#'
#' # Set early stopping parameter
#' learner = lrn("classif.xgboost",
#' nrounds = 100,
#' early_stopping_rounds = 10,
#' early_stopping_set = "test"
#' validate = "internal_valid"
#' )
#'
#' # Train learner with early stopping
#' learner$train(task)
#'
#' learner$internal_tuned_values
#' learner$internal_valid_scores
#' }
LearnerClassifXgboost = R6Class("LearnerClassifXgboost",
inherit = LearnerClassif,
Expand All @@ -77,6 +78,16 @@ LearnerClassifXgboost = R6Class("LearnerClassifXgboost",
#' @description
#' Creates a new instance of this [R6][R6::R6Class] class.
initialize = function() {

p_nrounds = p_int(1L,
tags = c("train", "hotstart", "internal_tuning"),
aggr = crate(function(x) as.integer(ceiling(mean(unlist(x)))), .parent = topenv()),
in_tune_fn = crate(function(domain, param_vals) {
assert_true(!is.null(param_vals$early_stopping), .var.name = "early stopping rounds is set")
assert_integerish(domain$upper, len = 1L, any.missing = FALSE) }, .parent = topenv()),
disable_in_tune = list(early_stopping_rounds = NULL)
)

ps = ps(
alpha = p_dbl(0, default = 0, tags = "train"),
approxcontrib = p_lgl(default = FALSE, tags = "predict"),
Expand All @@ -89,7 +100,6 @@ LearnerClassifXgboost = R6Class("LearnerClassifXgboost",
device = p_uty(default = "cpu", tags = "train"),
disable_default_eval_metric = p_lgl(default = FALSE, tags = "train"),
early_stopping_rounds = p_int(1L, default = NULL, special_vals = list(NULL), tags = "train"),
early_stopping_set = p_fct(c("none", "train", "test"), default = "none", tags = "train"),
eta = p_dbl(0, 1, default = 0.3, tags = c("train", "control")),
eval_metric = p_uty(tags = "train"),
feature_selector = p_fct(c("cyclic", "shuffle", "random", "greedy", "thrifty"), default = "cyclic", tags = "train", depends = quote(booster == "gblinear")),
Expand All @@ -108,8 +118,8 @@ LearnerClassifXgboost = R6Class("LearnerClassifXgboost",
min_child_weight = p_dbl(0, default = 1, tags = c("train", "control")),
missing = p_dbl(default = NA, tags = c("train", "predict"), special_vals = list(NA, NA_real_, NULL)),
monotone_constraints = p_uty(default = 0, tags = c("train", "control"), custom_check = crate(function(x) { checkmate::check_integerish(x, lower = -1, upper = 1, any.missing = FALSE) })), # nolint
nrounds = p_nrounds,
normalize_type = p_fct(c("tree", "forest"), default = "tree", tags = "train", depends = quote(booster == "dart")),
nrounds = p_int(1L, tags = c("train", "hotstart")),
nthread = p_int(1L, default = 1L, tags = c("train", "control", "threads")),
ntreelimit = p_int(1L, default = NULL, special_vals = list(NULL), tags = "predict"),
num_parallel_tree = p_int(1L, default = 1L, tags = c("train", "control")),
Expand Down Expand Up @@ -141,18 +151,17 @@ LearnerClassifXgboost = R6Class("LearnerClassifXgboost",
verbose = p_int(0L, 2L, default = 1L, tags = "train"),
watchlist = p_uty(default = NULL, tags = "train"),
xgb_model = p_uty(default = NULL, tags = "train")

)

# custom defaults
ps$values = list(nrounds = 1L, nthread = 1L, verbose = 0L, early_stopping_set = "none")
ps$values = list(nrounds = 1L, nthread = 1L, verbose = 0L)

super$initialize(
id = "classif.xgboost",
predict_types = c("response", "prob"),
param_set = ps,
feature_types = c("logical", "integer", "numeric"),
properties = c("weights", "missings", "twoclass", "multiclass", "importance", "hotstart_forward"),
properties = c("weights", "missings", "twoclass", "multiclass", "importance", "hotstart_forward", "internal_tuning", "validation"),
packages = c("mlr3learners", "xgboost"),
label = "Extreme Gradient Boosting",
man = "mlr3learners::mlr_learners_classif.xgboost"
Expand All @@ -174,8 +183,30 @@ LearnerClassifXgboost = R6Class("LearnerClassifXgboost",
set_names(imp$Gain, imp$Feature)
}
),

active = list(
#' @field internal_valid_scores
#' The last observation of the validation scores for all metrics.
#' Extracted from `model$evaluation_log`
internal_valid_scores = function() {
self$state$internal_valid_scores
},
#' @field internal_tuned_values
#' Returns the early stopped iterations if `early_stopping_rounds` was set during training.
internal_tuned_values = function() {
self$state$internal_tuned_values
},
#' @field validate
#' How to construct the internal validation data. This parameter can be either `NULL`,
#' a ratio, `"test"`, or `"predefined"`.
validate = function(rhs) {
if (!missing(rhs)) {
private$.validate = assert_validate(rhs)
}
private$.validate
}
),
private = list(
.validate = NULL,
.train = function(task) {

pv = self$param_set$get_values(tags = "train")
Expand Down Expand Up @@ -217,19 +248,18 @@ LearnerClassifXgboost = R6Class("LearnerClassifXgboost",
xgboost::setinfo(data, "weight", task$weights$weight)
}

if (pv$early_stopping_set != "none") {
pv$watchlist = c(pv$watchlist, list(train = data))
}

# the last element in the watchlist is used as the early stopping set

if (pv$early_stopping_set == "test" && !is.null(task$row_roles$test)) {
test_data = task$data(rows = task$row_roles$test, cols = task$feature_names)
test_label = nlvls - as.integer(task$truth(rows = task$row_roles$test))
internal_valid_task = task$internal_valid_task
if (!is.null(pv$early_stopping_rounds) && is.null(internal_valid_task)) {
stopf("Learner (%s): Configure field 'validate' to enable early stopping.", self$id)
}
if (!is.null(internal_valid_task)) {
test_data = internal_valid_task$data(cols = internal_valid_task$feature_names)
test_label = nlvls - as.integer(internal_valid_task$truth(rows = internal_valid_task$row_roles$test))
test_data = xgboost::xgb.DMatrix(data = as_numeric_matrix(test_data), label = test_label)
pv$watchlist = c(pv$watchlist, list(test = test_data))
}
pv$early_stopping_set = NULL

invoke(xgboost::xgb.train, data = data, .args = pv)
},
Expand Down Expand Up @@ -282,7 +312,6 @@ LearnerClassifXgboost = R6Class("LearnerClassifXgboost",
if (!is.null(pars_train$early_stopping_rounds)) {
stop("The parameter `early_stopping_rounds` is set. Early stopping and hotstarting are incompatible.")
}
pars$early_stopping_set = NULL

# Calculate additional boosting iterations
# niter in model and nrounds in ps should be equal after train and continue
Expand All @@ -295,10 +324,29 @@ LearnerClassifXgboost = R6Class("LearnerClassifXgboost",
data = xgboost::xgb.DMatrix(data = as_numeric_matrix(data), label = label)

invoke(xgboost::xgb.train, data = data, xgb_model = model, .args = pars)
},

.extract_internal_tuned_values = function() {
if (is.null(self$state$param_vals$early_stopping_rounds)) {
return(named_list())
}
list(nrounds = self$model$niter)
},

.extract_internal_valid_scores = function() {
if (is.null(self$model$evaluation_log)) {
return(named_list())
}
as.list(self$model$evaluation_log[
get(".N"),
set_names(get(".SD"), gsub("^test_", "", colnames(get(".SD",)))),
.SDcols = patterns("^test_")
])
}
)
)


#' @export
default_values.LearnerClassifXgboost = function(x, search_space, task, ...) { # nolint
special_defaults = list(
Expand Down
Loading

0 comments on commit c566afb

Please sign in to comment.