Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

allow steps of >1 in orsf_vs #70

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 0 additions & 23 deletions .github/workflows/draft-pdf.yaml

This file was deleted.

2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -48,4 +48,4 @@ Config/testthat/edition: 3
Encoding: UTF-8
LazyData: true
Roxygen: list(markdown = TRUE)
RoxygenNote: 7.2.3
RoxygenNote: 7.3.2
6 changes: 5 additions & 1 deletion NEWS.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
# aorsf 0.1.5 (unreleased)
# aorsf 0.1.6 (unreleased)

* added `n_predictor_drop` to `orsf_vs()`. Dropping one predictor at a time makes `orsf_vs()` slow for data with hundreds of predictors. Using a larger value for `n_predictor_drop` helps speed this up. The default value of `n_predictor_drop` is 1 to maintain backward compatibility.

# aorsf 0.1.5

* fixed an issue where omitting NA values would cause an error in regression forests.

Expand Down
4 changes: 3 additions & 1 deletion R/coerce_nans.R
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
#' S3 generic: dispatches on the class of `x`.
#'
#' @param x object whose class selects the method.
#' @param to passed along unchanged to the selected method.
#'
#' @noRd
coerce_nans <- function(x, to) {
  UseMethod("coerce_nans", x)
}

#' List method: applies `coerce_nans()` to every element of `x`.
#'
#' @param x a list; each element is dispatched through `coerce_nans()`.
#' @param to forwarded as `to` to each element-level call.
#'
#' @return a list with one result per element of `x`.
#'
#' @noRd
coerce_nans.list <- function(x, to) {
  lapply(X = x, FUN = coerce_nans, to = to)
}

#' @noRd
coerce_nans.factor <-
coerce_nans.integer <-
coerce_nans.double <-
Expand Down
36 changes: 27 additions & 9 deletions R/orsf_R6.R
Original file line number Diff line number Diff line change
Expand Up @@ -703,7 +703,9 @@ ObliqueForest <- R6::R6Class(

# Variable selection
# returns a data.table with variable selection info
select_variables = function(n_predictor_min, verbose_progress){
select_variables = function(n_predictor_min,
n_predictor_drop,
verbose_progress){

public_state <- list(verbose_progress = self$verbose_progress,
forest = self$forest,
Expand All @@ -712,7 +714,9 @@ ObliqueForest <- R6::R6Class(
object_trained <- self$trained

out <- try(
private$select_variables_internal(n_predictor_min, verbose_progress)
private$select_variables_internal(n_predictor_min,
n_predictor_drop,
verbose_progress)
)

private$restore_state(public_state, private_state = NULL)
Expand Down Expand Up @@ -2928,9 +2932,11 @@ ObliqueForest <- R6::R6Class(

},

select_variables_internal = function(n_predictor_min, verbose_progress){
select_variables_internal = function(n_predictor_min,
n_predictor_drop,
verbose_progress){

n_predictors <- length(private$data_names$x_original)
n_predictors <- length(private$data_names$x_ref_code)

# verbose progress on the forest should always be FALSE
# because for orsf_vs, verbosity is coordinated in R
Expand All @@ -2941,7 +2947,7 @@ ObliqueForest <- R6::R6Class(
stat_value = rep(NA_real_, n_predictors),
variables_included = vector(mode = 'list', length = n_predictors),
predictors_included = vector(mode = 'list', length = n_predictors),
predictor_dropped = rep(NA_character_, n_predictors)
predictor_dropped = vector(mode = 'list', length = n_predictors)
)

# if the forest was not trained prior to variable selection
Expand Down Expand Up @@ -3045,9 +3051,21 @@ ObliqueForest <- R6::R6Class(
cpp_args$mtry <- mtry_safe
cpp_output <- do.call(orsf_cpp, args = cpp_args)

worst_index <- which.min(cpp_output$importance)
worst_predictor <- colnames(cpp_args$x)[worst_index]
n_drop <- min(n_predictor_drop,
n_predictors - n_predictor_min)

if(n_drop > 0){

worst_index <- order(cpp_output$importance)[seq(n_drop)]

worst_predictor <- colnames(cpp_args$x)[worst_index]

} else {

worst_predictor <- NA_character_
n_drop <- 1

}

.variables_included <- with(
variable_key,
Expand All @@ -3062,8 +3080,8 @@ ObliqueForest <- R6::R6Class(
predictor_dropped = worst_predictor)]

cpp_args$x <- cpp_args$x[, -worst_index, drop = FALSE]
n_predictors <- n_predictors - 1
current_progress <- current_progress + 1
n_predictors <- n_predictors - n_drop
current_progress <- current_progress + n_drop

}

Expand Down
5 changes: 4 additions & 1 deletion R/orsf_data_prep.R
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@

#' S3 generic: dispatches on the class of `data`.
#'
#' @param data object whose class selects the method.
#' @param ... passed along to the selected method.
#'
#' @noRd
orsf_data_prep <- function(data, ...) {
  UseMethod("orsf_data_prep", data)
}

#' @noRd
orsf_data_prep.list <- function(data, ...){

lengths <- vapply(data, length, integer(1))
Expand Down Expand Up @@ -43,12 +44,14 @@ orsf_data_prep.list <- function(data, ...){

}

#' Recipe method: returns the `template` element of `data`.
#'
#' @param data an object of class `recipe`.
#' @param ... not used by this method.
#'
#' @noRd
orsf_data_prep.recipe <- function(data, ...) {
  template <- getElement(data, "template")
  template
}

#' Data frame method: returns `data` unchanged.
#'
#' @param data a data frame.
#' @param ... not used by this method.
#'
#' @noRd
orsf_data_prep.data.frame <- function(data, ...) {
  # Nothing to convert: hand the input back as-is.
  data
}
20 changes: 19 additions & 1 deletion R/orsf_vs.R
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#'
#' @inheritParams predict.ObliqueForest
#' @param n_predictor_min (*integer*) the minimum number of predictors allowed
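#' @param n_predictor_drop (*integer*) the number of predictors dropped at each step. Larger values speed up selection when there are many predictors; a step drops fewer predictors if needed so that `n_predictor_min` predictors remain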
#' @param verbose_progress (*logical*) not implemented yet. Should progress be printed to the console?
#'
#' @return a [data.table][data.table::data.table-package] with four columns:
Expand Down Expand Up @@ -38,8 +39,15 @@

orsf_vs <- function(object,
n_predictor_min = 3,
n_predictor_drop = 1,
verbose_progress = NULL){

# Variable selection drops the predictors with the lowest importance, so
# a forest specified with importance = 'none' cannot be used here.
# The inherits() guard keeps `$` from being applied to an arbitrary input:
# a wrong-class `object` skips this check and is presumably rejected by the
# class check that follows -- TODO confirm check_arg_is() errors on mismatch.
if(inherits(object, 'ObliqueForest') && object$importance_type == 'none'){
  # stop() pastes its pieces together with no separator, so the first
  # piece needs its own trailing space.
  stop("object must be specified with importance ",
       "of 'anova', 'negate', or 'permute'",
       call. = FALSE)
}

check_arg_is(arg_value = object,
arg_name = 'object',
expected_class = 'ObliqueForest')
Expand All @@ -55,6 +63,14 @@ orsf_vs <- function(object,
arg_name = 'n_predictor_min',
bound = 1)


check_arg_type(arg_value = n_predictor_drop,
arg_name = 'n_predictor_drop',
expected_type = 'numeric')

check_arg_is_integer(arg_value = n_predictor_drop,
arg_name = 'n_predictor_drop')

check_arg_lt(arg_value = n_predictor_min,
arg_name = 'n_predictor_min',
bound = length(object$get_names_x()),
Expand All @@ -74,7 +90,9 @@ orsf_vs <- function(object,
arg_name = 'verbose_progress',
expected_length = 1)

object$select_variables(n_predictor_min, verbose_progress)
object$select_variables(n_predictor_min,
n_predictor_drop,
verbose_progress)

}

Expand Down
48 changes: 18 additions & 30 deletions man/orsf.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions man/orsf_control_cph.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions man/orsf_control_custom.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions man/orsf_control_fast.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions man/orsf_control_net.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading
Loading