Skip to content

Commit

Permalink
xgboostImpute update
Browse files Browse the repository at this point in the history
  • Loading branch information
alexkowa committed Jun 21, 2023
1 parent 963a302 commit 497b373
Showing 1 changed file with 13 additions and 7 deletions.
20 changes: 13 additions & 7 deletions R/xgboostImpute.R
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,15 @@
#' data(sleep)
#' xgboostImpute(Dream~BodyWgt+BrainWgt,data=sleep)
#' xgboostImpute(Dream+NonD~BodyWgt+BrainWgt,data=sleep)
#'
#' sleepx <- sleep
#' sleepx$Pred <- as.factor(sleepx$Pred)
#' sleepx$Pred[1] <- NA
#' @export
xgboostImpute <- function(formula, data, imp_var = TRUE,
imp_suffix = "imp", ..., verbose = FALSE,
nrounds=2, objective=NULL,
median = FALSE) {
median = FALSE){
check_data(data)
formchar <- as.character(formula)
lhs <- gsub(" ", "", strsplit(formchar[2], "\\+")[[1]])
Expand All @@ -32,11 +36,12 @@ xgboostImpute <- function(formula, data, imp_var = TRUE,
rhs_na <- apply(subset(data, select = rhs2), 1, function(x) any(is.na(x)))
#objective should be a vector of lenght equal to the lhs variables
if(!is.null(objective)){
stopfifnot(length(objective)!=length(lhs))
stopifnot(length(objective)!=length(lhs))
}
for (lhsV in lhs) {
form <- as.formula(paste(lhsV, "~", rhs))
lhs_vector <- data[[lhsV]]
num_class <- NULL
if (!any(is.na(lhs_vector))) {
cat(paste0("No missings in ", lhsV, ".\n"))
} else {
Expand All @@ -46,12 +51,13 @@ xgboostImpute <- function(formula, data, imp_var = TRUE,
dattmp <- subset(data, !rhs_na & !lhs_na)
labtmp <- dattmp[[lhsV]]
if(inherits(labtmp,"factor")){
labtmp <- as.integer(labtmp)
labtmp <- as.integer(labtmp)-1
if(length(unique(labtmp))==2){
objective <- "binary:logistic"
}else if(length(unique(labtmp))>2){
objective <- "mult:softmax"
objective <- "multi:softmax"
}
num_class <- max(labtmp)+1
}else if(inherits(labtmp,"numeric")){
if(length(unique(labtmp))==2){
warning("binary factor detected but not probably stored as factor.")
Expand All @@ -70,16 +76,16 @@ xgboostImpute <- function(formula, data, imp_var = TRUE,


mm <- model.matrix(form,dattmp)
mod <- xgboost::xgboost(data=mm, label = labtmp,
nrounds=nrounds, objective=objective)
mod <- xgboost::xgboost(data = mm, label = labtmp,
nrounds=nrounds, objective=objective, num_class = num_class)
if (verbose)
message("Evaluating model for ", lhsV, " on ", sum(!rhs_na & lhs_na), " observations")
if (median & inherits(lhs_vector, "numeric")) {
predictions <- apply(
predict(mod, model.matrix(form,subset(data, !rhs_na & lhs_na)), predict.all = TRUE)$predictions,
1, median)
} else {
predictions <- predict(mod, subset(data, !rhs_na & lhs_na))$predictions
predictions <- predict(mod, model.matrix(as.formula(paste0("~",rhs)),subset(data, !rhs_na & lhs_na)))
}
data[!rhs_na & lhs_na, lhsV] <- predictions
}
Expand Down

0 comments on commit 497b373

Please sign in to comment.