From 8e74b8fe7e296b7d240a145c747fb37fe2cf7a17 Mon Sep 17 00:00:00 2001 From: Johannes Gussenbauer - QM Date: Tue, 2 Apr 2024 14:34:19 +0200 Subject: [PATCH] fix bug in xgboostImpute() for classifying categorical variable with containting 2 value; added drawing from probability for xgboostImpute() --- R/xgboostImpute.R | 42 +++++++++++++++++++++++++++++------------- 1 file changed, 29 insertions(+), 13 deletions(-) diff --git a/R/xgboostImpute.R b/R/xgboostImpute.R index b2923a0..01ffaab 100644 --- a/R/xgboostImpute.R +++ b/R/xgboostImpute.R @@ -59,11 +59,16 @@ xgboostImpute <- function(formula, data, imp_var = TRUE, currentClass <- NULL if(inherits(labtmp,"factor")){ currentClass <- "factor" + + predict_levels <- levels(labtmp) labtmp <- as.integer(labtmp)-1 if(length(unique(labtmp))==2){ objective <- "binary:logistic" + predict_levels <- predict_levels[unique(labtmp)+1] + labtmp <- as.integer(as.factor(labtmp))-1 + }else if(length(unique(labtmp))>2){ - objective <- "multi:softmax" + objective <- "multi:softprob" num_class <- max(labtmp)+1 } @@ -72,7 +77,7 @@ xgboostImpute <- function(formula, data, imp_var = TRUE, if(length(unique(labtmp))==2){ lvlsInt <- unique(labtmp) labtmp <- match(labtmp,lvlsInt)-1 - warning("binary factor detected but not probproperlyably stored as factor.") + warning("binary factor detected but not properly stored as factor.") objective <- "binary:logistic" }else{ objective <- "count:poisson"## Todo: this might not be wise as default @@ -93,27 +98,38 @@ xgboostImpute <- function(formula, data, imp_var = TRUE, mm <- model.matrix(form,dattmp) if(!is.null(num_class)){ mod <- xgboost::xgboost(data = mm, label = labtmp, - nrounds=nrounds, objective=objective, num_class = num_class, verbose = verbose,...) + nrounds=nrounds, objective=objective, num_class = num_class, verbose = verbose, ...) }else{ mod <- xgboost::xgboost(data = mm, label = labtmp, - nrounds=nrounds, objective=objective, verbose = verbose,...) + nrounds=nrounds, objective=objective, verbose = verbose, ...) } if (verbose) message("Evaluating model for ", lhsV, " on ", sum(!rhs_na & lhs_na), " observations") + predictions <- - predict(mod, model.matrix(formPred,subset(data, !rhs_na & lhs_na))) - if(currentClass=="factor"){ - if(is.null(num_class)){ - data[!rhs_na & lhs_na, lhsV] <- levels(dattmp[,lhsV])[as.numeric(predictions>.5)+1] + predict(mod, newdata = model.matrix(formPred,subset(data, !rhs_na & lhs_na)), reshape=TRUE) + + if(objective %in% c("binary:logistic","multi:softprob")){ + + if(objective =="binary:logistic"){ + predictions <- cbind(1-predictions,predictions) + } + + predict_num <- 1:ncol(predictions) + predictions <- apply(predictions,1,function(z,lev){ + z <- cumsum(z) + z_lev <- lev[z>runif(1)] + return(z_lev[1]) + },lev=predict_num) + + if(is.factor(dattmp[[lhsV]])){ + predictions <- predict_levels[predictions] }else{ - data[!rhs_na & lhs_na, lhsV] <- levels(dattmp[,lhsV])[predictions+1] + predictions <- lvlsInt[predictions] } - }else if(currentClass%in%c("numeric","integer")&objective=="binary:logistic"){ - data[!rhs_na & lhs_na, lhsV] <- lvlsInt[as.numeric(predictions>.5)+1] - }else{ - data[!rhs_na & lhs_na, lhsV] <- predictions } + data[!rhs_na & lhs_na, ][[lhsV]] <- predictions }