Skip to content

Commit

Permalink
cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
mllg committed Aug 16, 2024
1 parent e7db2c0 commit 0b19ba1
Show file tree
Hide file tree
Showing 13 changed files with 24 additions and 38 deletions.
3 changes: 1 addition & 2 deletions R/LearnerClassifCVGlmnet.R
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,7 @@ LearnerClassifCVGlmnet = R6Class("LearnerClassifCVGlmnet",
target = swap_levels(task$truth())
pv = self$param_set$get_values(tags = "train")
pv$family = ifelse(length(task$class_names) == 2L, "binomial", "multinomial")
pv$weights = get_weights(task, pv)
pv = remove_named(pv, "use_weights")
pv = get_weights(task, pv)

glmnet_invoke(data, target, pv, cv = TRUE)
},
Expand Down
3 changes: 1 addition & 2 deletions R/LearnerClassifGlmnet.R
Original file line number Diff line number Diff line change
Expand Up @@ -115,8 +115,7 @@ LearnerClassifGlmnet = R6Class("LearnerClassifGlmnet",
target = swap_levels(task$truth())
pv = self$param_set$get_values(tags = "train")
pv$family = ifelse(length(task$class_names) == 2L, "binomial", "multinomial")
pv$weights = get_weights(task, pv)
pv = remove_named(pv, "use_weights")
pv = get_weights(task, pv)

glmnet_invoke(data, target, pv)
},
Expand Down
3 changes: 1 addition & 2 deletions R/LearnerClassifMultinom.R
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,7 @@ LearnerClassifMultinom = R6Class("LearnerClassifMultinom",
private = list(
.train = function(task) {
pv = self$param_set$get_values(tags = "train")
pv$weights = get_weights(task, pv)
pv = remove_named(pv, "use_weights")
pv = get_weights(task, pv)

if (!is.null(pv$summ)) {
pv$summ = as.integer(pv$summ)
Expand Down
3 changes: 1 addition & 2 deletions R/LearnerClassifNnet.R
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,7 @@ LearnerClassifNnet = R6Class("LearnerClassifNnet",
private = list(
.train = function(task) {
pv = self$param_set$get_values(tags = "train")
pv$weights = get_weights(task, pv)
pv = remove_named(pv, "use_weights")
pv = get_weights(task, pv)

if (is.null(pv$formula)) {
pv$formula = task$formula()
Expand Down
3 changes: 1 addition & 2 deletions R/LearnerClassifRanger.R
Original file line number Diff line number Diff line change
Expand Up @@ -117,8 +117,7 @@ LearnerClassifRanger = R6Class("LearnerClassifRanger",
.train = function(task) {
pv = self$param_set$get_values(tags = "train")
pv = convert_ratio(pv, "mtry", "mtry.ratio", length(task$feature_names))
pv$case.weights = get_weights(task, pv)
pv = remove_named(pv, "use_weights")
pv = get_weights(task, pv, "case.weights")

invoke(ranger::ranger,
dependent.variable.name = task$target_names,
Expand Down
8 changes: 4 additions & 4 deletions R/LearnerClassifXgboost.R
Original file line number Diff line number Diff line change
Expand Up @@ -246,11 +246,11 @@ LearnerClassifXgboost = R6Class("LearnerClassifXgboost",
label = nlvls - as.integer(task$truth())
data = xgboost::xgb.DMatrix(data = as_numeric_matrix(data), label = label)

weights = get_weights(task, pv)
if (!is.null(weights)) {
xgboost::setinfo(data, "weight", weights)
pv = get_weights(task, pv)
if (!is.null(pv$weights)) {
xgboost::setinfo(data, "weight", pv$weights)
pv$weights = NULL
}
pv = remove_named(pv, "use_weights")

# the last element in the watchlist is used as the early stopping set

Expand Down
3 changes: 1 addition & 2 deletions R/LearnerRegrCVGlmnet.R
Original file line number Diff line number Diff line change
Expand Up @@ -103,8 +103,7 @@ LearnerRegrCVGlmnet = R6Class("LearnerRegrCVGlmnet",
data = as_numeric_matrix(task$data(cols = task$feature_names))
target = as_numeric_matrix(task$data(cols = task$target_names))
pv = self$param_set$get_values(tags = "train")
pv$weights = get_weights(task, pv)
pv = remove_named(pv, "use_weights")
pv = get_weights(task, pv)

glmnet_invoke(data, target, pv, cv = TRUE)
},
Expand Down
4 changes: 1 addition & 3 deletions R/LearnerRegrGlmnet.R
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,6 @@ LearnerRegrGlmnet = R6Class("LearnerRegrGlmnet",
type.multinomial = p_fct(c("ungrouped", "grouped"), tags = "train"),
upper.limits = p_uty(tags = "train"),
use_weights = p_lgl(default = FALSE, tags = "train")

)

ps$values = list(family = "gaussian")
Expand Down Expand Up @@ -104,8 +103,7 @@ LearnerRegrGlmnet = R6Class("LearnerRegrGlmnet",
data = as_numeric_matrix(task$data(cols = task$feature_names))
target = as_numeric_matrix(task$data(cols = task$target_names))
pv = self$param_set$get_values(tags = "train")
pv$weights = get_weights(task, pv)
pv = remove_named(pv, "use_weights")
pv = get_weights(task, pv)

glmnet_invoke(data, target, pv)
},
Expand Down
3 changes: 1 addition & 2 deletions R/LearnerRegrLM.R
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,7 @@ LearnerRegrLM = R6Class("LearnerRegrLM",
private = list(
.train = function(task) {
pv = self$param_set$get_values(tags = "train")
pv$weights = get_weights(task, pv)
pv = remove_named(pv, "use_weights")
pv = get_weights(task, pv)

invoke(stats::lm,
formula = task$formula(), data = task$data(),
Expand Down
3 changes: 1 addition & 2 deletions R/LearnerRegrNnet.R
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,7 @@ LearnerRegrNnet = R6Class("LearnerRegrNnet",
private = list(
.train = function(task) {
pv = self$param_set$get_values(tags = "train")
pv$weights = get_weights(task, pv)
pv = remove_named(pv, "use_weights")
pv = get_weights(task, pv)

if (is.null(pv$formula)) {
pv$formula = task$formula()
Expand Down
3 changes: 1 addition & 2 deletions R/LearnerRegrRanger.R
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,7 @@ LearnerRegrRanger = R6Class("LearnerRegrRanger",
.train = function(task) {
pv = self$param_set$get_values(tags = "train")
pv = convert_ratio(pv, "mtry", "mtry.ratio", length(task$feature_names))
pv$case.weights = get_weights(task, pv)
pv = remove_named(pv, "use_weights")
pv = get_weights(task, pv, "case.weights")

if (self$predict_type == "se") {
pv$keep.inbag = TRUE # nolint
Expand Down
8 changes: 4 additions & 4 deletions R/LearnerRegrXgboost.R
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,6 @@ LearnerRegrXgboost = R6Class("LearnerRegrXgboost",
watchlist = p_uty(default = NULL, tags = "train"),
xgb_model = p_uty(default = NULL, tags = "train"),
use_weights = p_lgl(default = FALSE, tags = "train")

)
# param deps

Expand Down Expand Up @@ -202,9 +201,10 @@ LearnerRegrXgboost = R6Class("LearnerRegrXgboost",
target = task$data(cols = task$target_names)
data = xgboost::xgb.DMatrix(data = as_numeric_matrix(data), label = data.matrix(target))

weights = get_weights(task, pv)
if (!is.null(weights)) {
xgboost::setinfo(data, "weight", weights)
pv = get_weights(task, pv)
if (!is.null(pv$weights)) {
xgboost::setinfo(data, "weight", pv$weights)
pv$weights = NULL
}

# the last element in the watchlist is used as the early stopping set
Expand Down
15 changes: 6 additions & 9 deletions R/helpers.R
Original file line number Diff line number Diff line change
Expand Up @@ -44,16 +44,13 @@ extract_loglik = function(self) {
stats::logLik(self$model)
}

get_weights = function(task, pv) {
tmp = c("weights", "weights_learner") %in% task$properties

if (tmp[1L]) { # old mlr3 version, deprecated weights
task$weights$weight
} else if (tmp[2L] && isTRUE(pv$use_weights)) {
task$weights_learner$weight
} else {
NULL
get_weights = function(task, pv, name = "weights") {
  # Moves the task's observation weights into the parameter list and strips the
  # `use_weights` flag so it is never forwarded to the underlying learner.
  #
  # @param task ([mlr3::Task]) task to extract weights from.
  # @param pv (named `list()`) parameter values, typically from
  #   `self$param_set$get_values(tags = "train")`.
  # @param name (`character(1)`) element name under which the weights are
  #   stored, defaults to "weights"; learners with a different argument name
  #   pass it explicitly (e.g. "case.weights" for ranger).
  # @return (named `list()`) modified parameter values with `use_weights`
  #   removed and, if enabled, the weights added under `name`.
  #
  # NOTE: `name` previously had no default, but most callers invoke
  # `get_weights(task, pv)`; without the default this errors lazily at runtime
  # whenever `use_weights` is TRUE.
  if (isTRUE(pv$use_weights)) {
    # NOTE(review): reads `task$weights`; confirm against the targeted mlr3
    # version whether `weights_learner` should be used instead.
    pv[[name]] = task$weights$weight
  }
  pv[["use_weights"]] = NULL

  return(pv)
}

opts_default_contrasts = list(contrasts = c("contr.treatment", "contr.poly"))

0 comments on commit 0b19ba1

Please sign in to comment.