Skip to content

Commit

Permalink
set nthread = 1 in all xgboost functions
Browse files Browse the repository at this point in the history
  • Loading branch information
EmilHvitfeldt committed Aug 15, 2023
1 parent 608be1c commit 1a04ccc
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 12 deletions.
18 changes: 12 additions & 6 deletions R/discretize_xgb.R
Original file line number Diff line number Diff line change
Expand Up @@ -249,13 +249,15 @@ xgb_binning <- function(df, outcome, predictor, sample_val, learn_rate,
xgb_train <- xgboost::xgb.DMatrix(
data = as.matrix(train[[predictor]], ncol = 1),
label = train[[outcome]],
weight = wts_train
weight = wts_train,
nthread = 1
)

xgb_test <- xgboost::xgb.DMatrix(
data = as.matrix(test[[predictor]], ncol = 1),
label = test[[outcome]],
weight = wts_test
weight = wts_test,
nthread = 1
)
} else {
if (length(levels) == 2) {
Expand All @@ -264,13 +266,15 @@ xgb_binning <- function(df, outcome, predictor, sample_val, learn_rate,
xgb_train <- xgboost::xgb.DMatrix(
data = as.matrix(train[[predictor]], ncol = 1),
label = ifelse(train[[outcome]] == levels[[1]], 0, 1),
weight = wts_train
weight = wts_train,
nthread = 1
)

xgb_test <- xgboost::xgb.DMatrix(
data = as.matrix(test[[predictor]], ncol = 1),
label = ifelse(test[[outcome]] == levels[[1]], 0, 1),
weight = wts_test
weight = wts_test,
nthread = 1
)
} else if (length(levels) >= 3) {
objective <- "multi:softprob" # returning estimated probability
Expand All @@ -279,13 +283,15 @@ xgb_binning <- function(df, outcome, predictor, sample_val, learn_rate,
xgb_train <- xgboost::xgb.DMatrix(
data = as.matrix(train[[predictor]], ncol = 1),
label = train[[outcome]],
weight = wts_train
weight = wts_train,
nthread = 1
)

xgb_test <- xgboost::xgb.DMatrix(
data = as.matrix(test[[predictor]], ncol = 1),
label = test[[outcome]],
weight = wts_test
weight = wts_test,
nthread = 1
)
} else {
rlang::abort(
Expand Down
18 changes: 12 additions & 6 deletions tests/testthat/test-discretize_xgb.R
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,14 @@ rec_credit <- credit_data_train %>%

xgb_credit_train <- xgboost::xgb.DMatrix(
data = as.matrix(bake(rec_credit, new_data = NULL)),
label = ifelse(credit_data_train[["Status"]] == "bad", 0, 1)
label = ifelse(credit_data_train[["Status"]] == "bad", 0, 1),
nthread = 1
)

xgb_credit_test <- xgboost::xgb.DMatrix(
data = as.matrix(bake(rec_credit, new_data = credit_data_test)),
label = ifelse(credit_data_test[["Status"]] == "bad", 0, 1)
label = ifelse(credit_data_test[["Status"]] == "bad", 0, 1),
nthread = 1
)

# Data for multi-classification problem testing
Expand All @@ -55,12 +57,14 @@ rec_attrition <- attrition_data_train %>%

xgb_attrition_train <- xgboost::xgb.DMatrix(
data = as.matrix(bake(rec_attrition, new_data = NULL)),
label = attrition_data_train$EducationField
label = attrition_data_train$EducationField,
nthread = 1
)

xgb_attrition_test <- xgboost::xgb.DMatrix(
data = as.matrix(bake(rec_attrition, new_data = attrition_data_test)),
label = attrition_data_test$EducationField
label = attrition_data_test$EducationField,
nthread = 1
)

ames$Sale_Price <- log10(ames$Sale_Price)
Expand All @@ -81,12 +85,14 @@ ames_rec <- ames_data_train %>%

xgb_ames_train <- xgboost::xgb.DMatrix(
data = as.matrix(bake(ames_rec, new_data = NULL)),
label = ames_data_train[["Sale_Price"]]
label = ames_data_train[["Sale_Price"]],
nthread = 1
)

xgb_ames_test <- xgboost::xgb.DMatrix(
data = as.matrix(bake(ames_rec, new_data = ames_data_test)),
label = ames_data_test[["Sale_Price"]]
label = ames_data_test[["Sale_Price"]],
nthread = 1
)

set.seed(8497)
Expand Down

0 comments on commit 1a04ccc

Please sign in to comment.