Skip to content

Commit

Permalink
fix bug in xgboostImpute() for classifying categorical variable with …
Browse files Browse the repository at this point in the history
…containting 2 value; added drawing from probability for xgboostImpute()
  • Loading branch information
Johannes Gussenbauer - QM committed Apr 2, 2024
1 parent 615e380 commit 8e74b8f
Showing 1 changed file with 29 additions and 13 deletions.
42 changes: 29 additions & 13 deletions R/xgboostImpute.R
Original file line number Diff line number Diff line change
Expand Up @@ -59,11 +59,16 @@ xgboostImpute <- function(formula, data, imp_var = TRUE,
currentClass <- NULL
if(inherits(labtmp,"factor")){
currentClass <- "factor"

predict_levels <- levels(labtmp)
labtmp <- as.integer(labtmp)-1
if(length(unique(labtmp))==2){
objective <- "binary:logistic"
predict_levels <- predict_levels[unique(labtmp)+1]
labtmp <- as.integer(as.factor(labtmp))-1

}else if(length(unique(labtmp))>2){
objective <- "multi:softmax"
objective <- "multi:softprob"
num_class <- max(labtmp)+1
}

Expand All @@ -72,7 +77,7 @@ xgboostImpute <- function(formula, data, imp_var = TRUE,
if(length(unique(labtmp))==2){
lvlsInt <- unique(labtmp)
labtmp <- match(labtmp,lvlsInt)-1
warning("binary factor detected but not probproperlyably stored as factor.")
warning("binary factor detected but not properly stored as factor.")
objective <- "binary:logistic"
}else{
objective <- "count:poisson"## Todo: this might not be wise as default
Expand All @@ -93,27 +98,38 @@ xgboostImpute <- function(formula, data, imp_var = TRUE,
mm <- model.matrix(form,dattmp)
if(!is.null(num_class)){
mod <- xgboost::xgboost(data = mm, label = labtmp,
nrounds=nrounds, objective=objective, num_class = num_class, verbose = verbose,...)
nrounds=nrounds, objective=objective, num_class = num_class, verbose = verbose, ...)
}else{
mod <- xgboost::xgboost(data = mm, label = labtmp,
nrounds=nrounds, objective=objective, verbose = verbose,...)
nrounds=nrounds, objective=objective, verbose = verbose, ...)
}

if (verbose)
message("Evaluating model for ", lhsV, " on ", sum(!rhs_na & lhs_na), " observations")

predictions <-
predict(mod, model.matrix(formPred,subset(data, !rhs_na & lhs_na)))
if(currentClass=="factor"){
if(is.null(num_class)){
data[!rhs_na & lhs_na, lhsV] <- levels(dattmp[,lhsV])[as.numeric(predictions>.5)+1]
predict(mod, newdata = model.matrix(formPred,subset(data, !rhs_na & lhs_na)), reshape=TRUE)

if(objective %in% c("binary:logistic","multi:softprob")){

if(objective =="binary:logistic"){
predictions <- cbind(1-predictions,predictions)
}

predict_num <- 1:ncol(predictions)
predictions <- apply(predictions,1,function(z,lev){
z <- cumsum(z)
z_lev <- lev[z>runif(1)]
return(z_lev[1])
},lev=predict_num)

if(is.factor(dattmp[[lhsV]])){
predictions <- predict_levels[predictions]
}else{
data[!rhs_na & lhs_na, lhsV] <- levels(dattmp[,lhsV])[predictions+1]
predictions <- lvlsInt[predictions]
}
}else if(currentClass%in%c("numeric","integer")&objective=="binary:logistic"){
data[!rhs_na & lhs_na, lhsV] <- lvlsInt[as.numeric(predictions>.5)+1]
}else{
data[!rhs_na & lhs_na, lhsV] <- predictions
}
data[!rhs_na & lhs_na, ][[lhsV]] <- predictions

}

Expand Down

0 comments on commit 8e74b8f

Please sign in to comment.