Merge pull request #296 from mlr-org/xgb
xgb tuning / validation
sebffischer authored Jun 25, 2024
2 parents bdd1bc0 + 9626674 commit e54589e
Showing 6 changed files with 121 additions and 59 deletions.
43 changes: 22 additions & 21 deletions R/LearnerClassifXgboost.R
@@ -12,6 +12,7 @@
#'
#' Note that using the `watchlist` parameter directly will lead to problems when wrapping this [`Learner`] in a
#' `mlr3pipelines` `GraphLearner` as the preprocessing steps will not be applied to the data in the watchlist.
#' See the section *Early Stopping and Validation* for how to do this correctly.
#'
#' @template note_xgboost
#' @section Initial parameter values:
@@ -30,11 +31,12 @@
#' - Adjusted default: 0.
#' - Reason for change: Reduce verbosity.
#'
#' @section Early stopping:
#' Early stopping can be used to find the optimal number of boosting rounds.
#' Set `early_stopping_rounds` to an integer value to monitor the performance of the model on the validation set while training.
#' @section Early Stopping and Validation:
#' In order to monitor the validation performance during training, set the `$validate` field of the Learner.
#' For information on how to configure the validation set, see the *Validation* section of [`mlr3::Learner`].
#'
#' This validation data can also be used for early stopping, which can be enabled by setting the `early_stopping_rounds` parameter.
#' The final validation scores (or, if early stopping was used, those of the best iteration) can be accessed via
#' `$internal_valid_scores`, and the optimal `nrounds` via `$internal_tuned_values`.
#' @templateVar id classif.xgboost
#' @template learner
#'
@@ -50,22 +52,18 @@
#' # Train learner with early stopping on spam data set
#' task = tsk("spam")
#'
#' # Split task into training and validation data
#' split = partition(task, ratio = 0.8)
#' task$divide(split$test)
#'
#' task
#'
#' # use 30 percent for validation
#' # Set early stopping parameter
#' learner = lrn("classif.xgboost",
#' nrounds = 100,
#' early_stopping_rounds = 10,
#' validate = "internal_valid"
#' validate = 0.3
#' )
#'
#' # Train learner with early stopping
#' learner$train(task)
#'
#' # Inspect optimal nrounds and validation performance
#' learner$internal_tuned_values
#' learner$internal_valid_scores
#' }
@@ -185,18 +183,20 @@ LearnerClassifXgboost = R6Class("LearnerClassifXgboost",
}
),
active = list(
#' @field internal_valid_scores
#' The last observation of the validation scores for all metrics.
#' Extracted from `model$evaluation_log`
#' @field internal_valid_scores (named `list()` or `NULL`)
#' The validation scores extracted from `model$evaluation_log`.
#' If early stopping is activated, this contains the validation scores of the model at the optimal `nrounds`,
#' otherwise those of the final model.
internal_valid_scores = function() {
self$state$internal_valid_scores
},
#' @field internal_tuned_values
#' Returns the early stopped iterations if `early_stopping_rounds` was set during training.
#' @field internal_tuned_values (named `list()` or `NULL`)
#' If early stopping is activated, this returns a list with `nrounds`,
#' extracted from `$best_iteration` of the model; otherwise `NULL`.
internal_tuned_values = function() {
self$state$internal_tuned_values
},
#' @field validate
#' @field validate (`numeric(1)` or `character(1)` or `NULL`)
#' How to construct the internal validation data. This parameter can be either `NULL`,
#' a ratio, `"test"`, or `"predefined"`.
validate = function(rhs) {
@@ -329,17 +329,18 @@ LearnerClassifXgboost = R6Class("LearnerClassifXgboost",

.extract_internal_tuned_values = function() {
if (is.null(self$state$param_vals$early_stopping_rounds)) {
return(named_list())
return(NULL)
}
list(nrounds = self$model$niter)
list(nrounds = self$model$best_iteration)
},

.extract_internal_valid_scores = function() {
if (is.null(self$model$evaluation_log)) {
return(named_list())
return(NULL)
}
iter = if (!is.null(self$model$best_iteration)) self$model$best_iteration else self$model$niter
as.list(self$model$evaluation_log[
get(".N"),
iter,
set_names(get(".SD"), gsub("^test_", "", colnames(get(".SD",)))),
.SDcols = patterns("^test_")
])
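Taken together, the changes above replace the old `partition()`/`$divide()` example with the new `validate` field. For context, a minimal end-to-end sketch of the resulting interface (the return values in the comments are illustrative):

```r
library(mlr3)
library(mlr3learners)

task = tsk("spam")

# Hold out 30% of the training data as an internal validation set
learner = lrn("classif.xgboost",
  nrounds = 100,
  early_stopping_rounds = 10,
  validate = 0.3
)

learner$train(task)

# Boosting round selected by early stopping, e.g. list(nrounds = 37)
learner$internal_tuned_values

# Validation scores at that round, e.g. list(logloss = 0.12)
learner$internal_valid_scores
```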
32 changes: 20 additions & 12 deletions R/LearnerRegrXgboost.R
@@ -11,9 +11,10 @@
#'
#' Note that using the `watchlist` parameter directly will lead to problems when wrapping this [`Learner`] in a
#' `mlr3pipelines` `GraphLearner` as the preprocessing steps will not be applied to the data in the watchlist.
#' See the section *Early Stopping and Validation* for how to do this correctly.
#'
#' @template note_xgboost
#' @inheritSection mlr_learners_classif.xgboost Early stopping
#' @inheritSection mlr_learners_classif.xgboost Early Stopping and Validation
#' @inheritSection mlr_learners_classif.xgboost Initial parameter values
#'
#' @templateVar id regr.xgboost
Expand Down Expand Up @@ -41,6 +42,10 @@
#'
#' # Train learner with early stopping
#' learner$train(task)
#'
#' # Inspect optimal nrounds and validation performance
#' learner$internal_tuned_values
#' learner$internal_valid_scores
#' }
LearnerRegrXgboost = R6Class("LearnerRegrXgboost",
inherit = LearnerRegr,
@@ -157,20 +162,22 @@ LearnerRegrXgboost = R6Class("LearnerRegrXgboost",
),

active = list(
#' @field internal_valid_scores (named `list()`)\cr
#' The last observation of the validation scores for all metrics.
#' Extracted from `model$evaluation_log`
#' @field internal_valid_scores (named `list()` or `NULL`)
#' The validation scores extracted from `model$evaluation_log`.
#' If early stopping is activated, this contains the validation scores of the model at the optimal `nrounds`,
#' otherwise those of the final model.
internal_valid_scores = function() {
self$state$internal_valid_scores
},
#' @field internal_tuned_values (named `list()`)\cr
#' Returns the early stopped iterations if `early_stopping_rounds` was set during training.
#' @field internal_tuned_values (named `list()` or `NULL`)
#' If early stopping is activated, this returns a list with `nrounds`,
#' extracted from `$best_iteration` of the model; otherwise `NULL`.
internal_tuned_values = function() {
self$state$internal_tuned_values
},
#' @field validate
#' @field validate (`numeric(1)` or `character(1)` or `NULL`)
#' How to construct the internal validation data. This parameter can be either `NULL`,
#' a ratio, `"test"`, or `"internal_valid"`.
#' a ratio, `"test"`, or `"predefined"`.
validate = function(rhs) {
if (!missing(rhs)) {
private$.validate = assert_validate(rhs)
@@ -212,7 +219,7 @@ LearnerRegrXgboost = R6Class("LearnerRegrXgboost",

invoke(xgboost::xgb.train, data = data, .args = pv)
},

#' Returns the `$best_iteration` when early stopping is activated.
.predict = function(task) {
pv = self$param_set$get_values(tags = "predict")
model = self$model
@@ -244,17 +251,18 @@ LearnerRegrXgboost = R6Class("LearnerRegrXgboost",

.extract_internal_tuned_values = function() {
if (is.null(self$state$param_vals$early_stopping_rounds)) {
return(named_list())
return(NULL)
}
list(nrounds = self$model$niter)
list(nrounds = self$model$best_iteration)
},

.extract_internal_valid_scores = function() {
if (is.null(self$model$evaluation_log)) {
return(named_list())
return(NULL)
}
iter = if (!is.null(self$model$best_iteration)) self$model$best_iteration else self$model$niter
as.list(self$model$evaluation_log[
get(".N"),
iter,
set_names(get(".SD"), gsub("^test_", "", colnames(get(".SD",)))),
.SDcols = patterns("^test_")
])
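The core change in `.extract_internal_valid_scores()` (identical in both learners) is selecting the `evaluation_log` row at the best iteration instead of always taking the last row (`.N`). Below is a self-contained sketch of that selection logic on a toy log; the `get()` indirection from the package code (only there to quiet R CMD check) is dropped, and `best_row` is an illustrative name:

```r
library(data.table)

# Toy stand-in for xgboost's model$evaluation_log
evaluation_log = data.table(
  iter = 1:5,
  test_logloss = c(0.50, 0.40, 0.35, 0.37, 0.39)
)
best_iteration = 3L  # what xgboost records when early stopping triggers

# Fall back to the final round if early stopping was not active
best_row = if (!is.null(best_iteration)) best_iteration else nrow(evaluation_log)

# Select all test_* columns in that row and strip the "test_" prefix
scores = as.list(evaluation_log[
  best_row,
  setNames(.SD, gsub("^test_", "", colnames(.SD))),
  .SDcols = patterns("^test_")
])
scores$logloss  # 0.35, the score at the early-stopped iteration
```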
32 changes: 18 additions & 14 deletions man/mlr_learners_classif.xgboost.Rd

Some generated files are not rendered by default.

31 changes: 21 additions & 10 deletions man/mlr_learners_regr.xgboost.Rd

Some generated files are not rendered by default.

21 changes: 20 additions & 1 deletion tests/testthat/test_classif_xgboost.R
@@ -90,7 +90,7 @@ test_that("validation and inner tuning", {
early_stopping_rounds = NULL
)
learner$train(task)
expect_equal(learner$internal_tuned_values, named_list())
expect_equal(learner$internal_tuned_values, NULL)
expect_named(learner$model$evaluation_log, c("iter", "test_logloss"))
expect_list(learner$internal_valid_scores, types = "numeric")
expect_equal(names(learner$internal_valid_scores), "logloss")
@@ -104,4 +104,23 @@
learner$param_set$set_values(early_stopping_rounds = 10)
learner$param_set$disable_internal_tuning("nrounds")
expect_equal(learner$param_set$values$early_stopping_rounds, NULL)

learner = lrn("classif.xgboost",
nrounds = 100,
early_stopping_rounds = 5,
validate = 0.3
)
learner$train(task)
expect_equal(learner$internal_valid_scores$logloss,
learner$model$evaluation_log$test_logloss[learner$internal_tuned_values$nrounds])

learner = lrn("classif.xgboost")
learner$train(task)
expect_true(is.null(learner$internal_valid_scores))
expect_true(is.null(learner$internal_tuned_values))

learner = lrn("classif.xgboost", validate = 0.3, nrounds = 10)
learner$train(task)
expect_equal(learner$internal_valid_scores$logloss, learner$model$evaluation_log$test_logloss[10L])
expect_true(is.null(learner$internal_tuned_values))
})
21 changes: 20 additions & 1 deletion tests/testthat/test_regr_xgboost.R
@@ -71,7 +71,7 @@ test_that("validation and inner tuning", {
early_stopping_rounds = NULL
)
learner$train(task)
expect_equal(learner$internal_tuned_values, named_list())
expect_equal(learner$internal_tuned_values, NULL)
expect_named(learner$model$evaluation_log, c("iter", "test_rmse"))
expect_list(learner$internal_valid_scores, types = "numeric")
expect_equal(names(learner$internal_valid_scores), "rmse")
@@ -85,4 +85,23 @@
learner$param_set$set_values(early_stopping_rounds = 10)
learner$param_set$disable_internal_tuning("nrounds")
expect_equal(learner$param_set$values$early_stopping_rounds, NULL)

learner = lrn("regr.xgboost",
nrounds = 100L,
early_stopping_rounds = 5,
validate = 0.2
)
learner$train(task)
expect_equal(learner$internal_valid_scores$rmse,
learner$model$evaluation_log$test_rmse[learner$internal_tuned_values$nrounds])

learner = lrn("regr.xgboost")
learner$train(task)
expect_true(is.null(learner$internal_valid_scores))
expect_true(is.null(learner$internal_tuned_values))

learner = lrn("regr.xgboost", validate = 0.3, nrounds = 10)
learner$train(task)
expect_equal(learner$internal_valid_scores$rmse, learner$model$evaluation_log$test_rmse[10L])
expect_true(is.null(learner$internal_tuned_values))
})
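The `internal_tuned_values` and `disable_internal_tuning()` plumbing exercised by these tests is what allows `nrounds` to be tuned internally via early stopping while other hyperparameters are tuned by a regular tuner. A sketch of that combination, assuming the internal-tuning API exposed by paradox/mlr3tuning around the time of this PR (`to_tune(internal = TRUE)` together with `validate = "test"`):

```r
library(mlr3)
library(mlr3learners)
library(mlr3tuning)

# nrounds is early-stopped on each resampling test set;
# eta is tuned by random search
learner = lrn("classif.xgboost",
  nrounds = to_tune(upper = 500, internal = TRUE),
  early_stopping_rounds = 10,
  eta = to_tune(0.01, 0.3, logscale = TRUE),
  validate = "test"
)

at = auto_tuner(
  tuner = tnr("random_search"),
  learner = learner,
  resampling = rsmp("cv", folds = 3),
  measure = msr("classif.logloss"),
  term_evals = 10
)

at$train(tsk("spam"))

# The tuning result carries the internally tuned nrounds
# (field name as assumed from the mlr3tuning internal-tuning API)
at$tuning_result$internal_tuned_values
```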
