Skip to content

Commit

Permalink
small changes to make the xgboostImpute version closer to CRAN readiness
Browse files Browse the repository at this point in the history
  • Loading branch information
alexkowa committed Nov 9, 2023
1 parent 7a6cfa3 commit cb5a541
Show file tree
Hide file tree
Showing 5 changed files with 38 additions and 14 deletions.
1 change: 1 addition & 0 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ Imports:
laeken,
ranger,
MASS,
xgboost,
data.table(>= 1.9.4)
Suggests:
dplyr,
Expand Down
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ export(scattMiss)
export(scattmatrixMiss)
export(spineMiss)
export(tableMiss)
export(xgboostImpute)
import(Rcpp)
import(colorspace)
import(data.table)
Expand Down
32 changes: 21 additions & 11 deletions R/xgboostImpute.R
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,14 @@
#' @param imp_var `TRUE`/`FALSE` whether a `TRUE`/`FALSE` variable should be
#' created for each imputed variable to show the imputation status
#' @param imp_suffix suffix used for TF imputation variables
#' @param ... Arguments passed to [xgboost::xgboost()]
#' @param verbose Show the number of observations used for training
#' and evaluating the xgboost model. This parameter is also passed down to
#' [xgboost::xgboost()] to show computation status.
#' @param ... Arguments passed to [xgboost::xgboost()]
#' @param nrounds max number of boosting iterations,
#' argument passed to [xgboost::xgboost()]
#' @param objective objective for xgboost,
#' argument passed to [xgboost::xgboost()]
#' @return the imputed data set.
#' @family imputation methods
#' @examples
Expand Down Expand Up @@ -63,32 +67,36 @@ xgboostImpute <- function(formula, data, imp_var = TRUE,
num_class <- max(labtmp)+1
}

}else if(inherits(labtmp,"numeric")){
currentClass <- "numeric"
}else if(inherits(labtmp,"integer")){
currentClass <- "integer"
if(length(unique(labtmp))==2){
warning("binary factor detected but not probably stored as factor.")
lvlsInt <- unique(labtmp)
labtmp <- match(labtmp,lvlsInt)-1
warning("binary factor detected but not properly stored as factor.")
objective <- "binary:logistic"
}else{
objective <- "reg:squarederror"
objective <- "count:poisson"## Todo: this might not be wise as default
}
}else if(inherits(labtmp,"integer")){
currentClass <- "integer"
}else if(inherits(labtmp,"numeric")){
currentClass <- "numeric"
if(length(unique(labtmp))==2){
warning("binary factor detected but not probably stored as factor.")
lvlsInt <- unique(labtmp)
labtmp <- match(labtmp,lvlsInt)-1
warning("binary factor detected but not properly stored as factor.")
objective <- "binary:logistic"
}else{
objective <- "count:poisson"## Todo: this might not be wise as default
objective <- "reg:squarederror"
}
}


mm <- model.matrix(form,dattmp)
if(!is.null(num_class)){
mod <- xgboost::xgboost(data = mm, label = labtmp,
nrounds=nrounds, objective=objective, num_class = num_class, verbose = FALSE,...)
nrounds=nrounds, objective=objective, num_class = num_class, verbose = verbose,...)
}else{
mod <- xgboost::xgboost(data = mm, label = labtmp,
nrounds=nrounds, objective=objective, verbose = FALSE,...)
nrounds=nrounds, objective=objective, verbose = verbose,...)
}

if (verbose)
Expand All @@ -101,6 +109,8 @@ xgboostImpute <- function(formula, data, imp_var = TRUE,
}else{
data[!rhs_na & lhs_na, lhsV] <- levels(dattmp[,lhsV])[predictions+1]
}
}else if(currentClass%in%c("numeric","integer")&objective=="binary:logistic"){
data[!rhs_na & lhs_na, lhsV] <- lvlsInt[as.numeric(predictions>.5)+1]
}else{
data[!rhs_na & lhs_na, lhsV] <- predictions
}
Expand Down
4 changes: 4 additions & 0 deletions inst/tinytest/test_matchImpute.R
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
library(VIM)
message("matchImpute general")
# Test helper: return a copy of `d` in which the cells addressed by row
# index `i` and column index `col` (second column by default) are set to NA.
# The object passed in by the caller is left untouched.
setna <- function(d, i, col = 2) {
  out <- d
  out[i, col] <- NA
  out
}
d <- data.frame(x=LETTERS[1:6],y=as.double(1:6),z=as.double(1:6),
w=ordered(LETTERS[1:6]), stringsAsFactors = FALSE)
dorig <- rbind(d,d)
Expand Down
14 changes: 11 additions & 3 deletions inst/tinytest/test_xgboostImpute.R
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@ max_dist <- function(x, y) {

df$y[1:3] <- NA
df$fac[3:5] <- NA

df$binNum <- as.integer(df$fac)+17
df$binInt <- as.integer(df$fac)+17L
# xgboostImpute accuracy", {
df.out <- xgboostImpute(y ~ x, df)
expect_true(
Expand All @@ -28,10 +29,17 @@ df$fac[3:5] <- NA

# factor response predicted accurately", {
df.out <- xgboostImpute(fac ~ x, df)
df.out[df.out$fac_imp,]
expect_identical(df.out$fac, as.factor(df$x >= 0))
#


# integer binary response predicted accurately", {
expect_warning(df.out <- xgboostImpute(binInt ~ x, df))
expect_identical(df.out$binInt==19, df$x >= 0)
#
# numeric binary response predicted accurately", {
expect_warning(df.out <- xgboostImpute(binNum ~ x, df))
expect_identical(df.out$binNum==19, df$x >= 0)
#
# factor regressor used reasonably", {
df2 <- df
df2$x[1:10] <- NA
Expand Down

0 comments on commit cb5a541

Please sign in to comment.