From 0b19ba1a5c481f4f5829193b98973dfd9a55a0f8 Mon Sep 17 00:00:00 2001 From: Michel Lang Date: Fri, 16 Aug 2024 10:48:19 +0200 Subject: [PATCH] cleanup --- R/LearnerClassifCVGlmnet.R | 3 +-- R/LearnerClassifGlmnet.R | 3 +-- R/LearnerClassifMultinom.R | 3 +-- R/LearnerClassifNnet.R | 3 +-- R/LearnerClassifRanger.R | 3 +-- R/LearnerClassifXgboost.R | 8 ++++---- R/LearnerRegrCVGlmnet.R | 3 +-- R/LearnerRegrGlmnet.R | 4 +--- R/LearnerRegrLM.R | 3 +-- R/LearnerRegrNnet.R | 3 +-- R/LearnerRegrRanger.R | 3 +-- R/LearnerRegrXgboost.R | 8 ++++---- R/helpers.R | 15 ++++++--------- 13 files changed, 24 insertions(+), 38 deletions(-) diff --git a/R/LearnerClassifCVGlmnet.R b/R/LearnerClassifCVGlmnet.R index 4512076b..1867e844 100644 --- a/R/LearnerClassifCVGlmnet.R +++ b/R/LearnerClassifCVGlmnet.R @@ -105,8 +105,7 @@ LearnerClassifCVGlmnet = R6Class("LearnerClassifCVGlmnet", target = swap_levels(task$truth()) pv = self$param_set$get_values(tags = "train") pv$family = ifelse(length(task$class_names) == 2L, "binomial", "multinomial") - pv$weights = get_weights(task, pv) - pv = remove_named(pv, "use_weights") + pv = get_weights(task, pv) glmnet_invoke(data, target, pv, cv = TRUE) }, diff --git a/R/LearnerClassifGlmnet.R b/R/LearnerClassifGlmnet.R index 2fc3244c..a20127f6 100644 --- a/R/LearnerClassifGlmnet.R +++ b/R/LearnerClassifGlmnet.R @@ -115,8 +115,7 @@ LearnerClassifGlmnet = R6Class("LearnerClassifGlmnet", target = swap_levels(task$truth()) pv = self$param_set$get_values(tags = "train") pv$family = ifelse(length(task$class_names) == 2L, "binomial", "multinomial") - pv$weights = get_weights(task, pv) - pv = remove_named(pv, "use_weights") + pv = get_weights(task, pv) glmnet_invoke(data, target, pv) }, diff --git a/R/LearnerClassifMultinom.R b/R/LearnerClassifMultinom.R index 08528c44..94fda130 100644 --- a/R/LearnerClassifMultinom.R +++ b/R/LearnerClassifMultinom.R @@ -64,8 +64,7 @@ LearnerClassifMultinom = R6Class("LearnerClassifMultinom", private = list( .train = function(task) { pv = self$param_set$get_values(tags = "train") - pv$weights = get_weights(task, pv) - pv = remove_named(pv, "use_weights") + pv = get_weights(task, pv) if (!is.null(pv$summ)) { pv$summ = as.integer(pv$summ) diff --git a/R/LearnerClassifNnet.R b/R/LearnerClassifNnet.R index 86f011a4..ed13401b 100644 --- a/R/LearnerClassifNnet.R +++ b/R/LearnerClassifNnet.R @@ -72,8 +72,7 @@ LearnerClassifNnet = R6Class("LearnerClassifNnet", private = list( .train = function(task) { pv = self$param_set$get_values(tags = "train") - pv$weights = get_weights(task, pv) - pv = remove_named(pv, "use_weights") + pv = get_weights(task, pv) if (is.null(pv$formula)) { pv$formula = task$formula() diff --git a/R/LearnerClassifRanger.R b/R/LearnerClassifRanger.R index f0503356..cd7b82b8 100644 --- a/R/LearnerClassifRanger.R +++ b/R/LearnerClassifRanger.R @@ -117,8 +117,7 @@ LearnerClassifRanger = R6Class("LearnerClassifRanger", .train = function(task) { pv = self$param_set$get_values(tags = "train") pv = convert_ratio(pv, "mtry", "mtry.ratio", length(task$feature_names)) - pv$case.weights = get_weights(task, pv) - pv = remove_named(pv, "use_weights") + pv = get_weights(task, pv, "case.weights") invoke(ranger::ranger, dependent.variable.name = task$target_names, diff --git a/R/LearnerClassifXgboost.R b/R/LearnerClassifXgboost.R index 9ef6dc3d..914456f3 100644 --- a/R/LearnerClassifXgboost.R +++ b/R/LearnerClassifXgboost.R @@ -246,11 +246,11 @@ LearnerClassifXgboost = R6Class("LearnerClassifXgboost", label = nlvls - as.integer(task$truth()) data = xgboost::xgb.DMatrix(data = as_numeric_matrix(data), label = label) - weights = get_weights(task, pv) - if (!is.null(weights)) { - xgboost::setinfo(data, "weight", weights) + pv = get_weights(task, pv) + if (!is.null(pv$weights)) { + xgboost::setinfo(data, "weight", pv$weights) + pv$weights = NULL } - pv = remove_named(pv, "use_weights") # the last element in the watchlist is used as the early stopping set diff --git a/R/LearnerRegrCVGlmnet.R b/R/LearnerRegrCVGlmnet.R index d1b01857..e60978cd 100644 --- a/R/LearnerRegrCVGlmnet.R +++ b/R/LearnerRegrCVGlmnet.R @@ -103,8 +103,7 @@ LearnerRegrCVGlmnet = R6Class("LearnerRegrCVGlmnet", data = as_numeric_matrix(task$data(cols = task$feature_names)) target = as_numeric_matrix(task$data(cols = task$target_names)) pv = self$param_set$get_values(tags = "train") - pv$weights = get_weights(task, pv) - pv = remove_named(pv, "use_weights") + pv = get_weights(task, pv) glmnet_invoke(data, target, pv, cv = TRUE) }, diff --git a/R/LearnerRegrGlmnet.R b/R/LearnerRegrGlmnet.R index f5047dba..9783cf41 100644 --- a/R/LearnerRegrGlmnet.R +++ b/R/LearnerRegrGlmnet.R @@ -70,7 +70,6 @@ LearnerRegrGlmnet = R6Class("LearnerRegrGlmnet", type.multinomial = p_fct(c("ungrouped", "grouped"), tags = "train"), upper.limits = p_uty(tags = "train"), use_weights = p_lgl(default = FALSE, tags = "train") - ) ps$values = list(family = "gaussian") @@ -104,8 +103,7 @@ LearnerRegrGlmnet = R6Class("LearnerRegrGlmnet", data = as_numeric_matrix(task$data(cols = task$feature_names)) target = as_numeric_matrix(task$data(cols = task$target_names)) pv = self$param_set$get_values(tags = "train") - pv$weights = get_weights(task, pv) - pv = remove_named(pv, "use_weights") + pv = get_weights(task, pv) glmnet_invoke(data, target, pv) }, diff --git a/R/LearnerRegrLM.R b/R/LearnerRegrLM.R index da744206..767e328d 100644 --- a/R/LearnerRegrLM.R +++ b/R/LearnerRegrLM.R @@ -62,8 +62,7 @@ LearnerRegrLM = R6Class("LearnerRegrLM", private = list( .train = function(task) { pv = self$param_set$get_values(tags = "train") - pv$weights = get_weights(task, pv) - pv = remove_named(pv, "use_weights") + pv = get_weights(task, pv) invoke(stats::lm, formula = task$formula(), data = task$data(), diff --git a/R/LearnerRegrNnet.R b/R/LearnerRegrNnet.R index 513ddfac..7cc6fceb 100644 --- a/R/LearnerRegrNnet.R +++ b/R/LearnerRegrNnet.R @@ -71,8 +71,7 @@ LearnerRegrNnet = R6Class("LearnerRegrNnet", private = list( .train = function(task) { pv = self$param_set$get_values(tags = "train") - pv$weights = get_weights(task, pv) - pv = remove_named(pv, "use_weights") + pv = get_weights(task, pv) if (is.null(pv$formula)) { pv$formula = task$formula() diff --git a/R/LearnerRegrRanger.R b/R/LearnerRegrRanger.R index 55af49e6..4dc024d6 100644 --- a/R/LearnerRegrRanger.R +++ b/R/LearnerRegrRanger.R @@ -108,8 +108,7 @@ LearnerRegrRanger = R6Class("LearnerRegrRanger", .train = function(task) { pv = self$param_set$get_values(tags = "train") pv = convert_ratio(pv, "mtry", "mtry.ratio", length(task$feature_names)) - pv$case.weights = get_weights(task, pv) - pv = remove_named(pv, "use_weights") + pv = get_weights(task, pv, "case.weights") if (self$predict_type == "se") { pv$keep.inbag = TRUE # nolint diff --git a/R/LearnerRegrXgboost.R b/R/LearnerRegrXgboost.R index c3e669a8..74ad885b 100644 --- a/R/LearnerRegrXgboost.R +++ b/R/LearnerRegrXgboost.R @@ -129,7 +129,6 @@ LearnerRegrXgboost = R6Class("LearnerRegrXgboost", watchlist = p_uty(default = NULL, tags = "train"), xgb_model = p_uty(default = NULL, tags = "train"), use_weights = p_lgl(default = FALSE, tags = "train") - ) # param deps @@ -202,9 +201,10 @@ LearnerRegrXgboost = R6Class("LearnerRegrXgboost", target = task$data(cols = task$target_names) data = xgboost::xgb.DMatrix(data = as_numeric_matrix(data), label = data.matrix(target)) - weights = get_weights(task, pv) - if (!is.null(weights)) { - xgboost::setinfo(data, "weight", weights) + pv = get_weights(task, pv) + if (!is.null(pv$weights)) { + xgboost::setinfo(data, "weight", pv$weights) + pv$weights = NULL } # the last element in the watchlist is used as the early stopping set diff --git a/R/helpers.R b/R/helpers.R index 7f0b6817..584eb0df 100644 --- a/R/helpers.R +++ b/R/helpers.R @@ -44,16 +44,13 @@ extract_loglik = function(self) { stats::logLik(self$model) } -get_weights = function(task, pv) { - tmp = c("weights", "weights_learner") %in% task$properties - - if (tmp[1L]) { # old mlr3 version, deprecated weights - task$weights$weight - } else if (tmp[2L] && isTRUE(pv$use_weights)) { - task$weights_learner$weight - } else { - NULL +get_weights = function(task, pv, name) { + if (isTRUE(pv$use_weights)) { + pv[[name]] = task$weights$weight } + pv[["use_weights"]] = NULL + + return(pv) } opts_default_contrasts = list(contrasts = c("contr.treatment", "contr.poly"))