importance weights not compatible with DALEXtra::model_profile #242
Your formula tidymodels enforces the constraint that the outcome should not be used (in any way) when making predictions, even if that column is available at prediction time. This is to eliminate information leakage: the outcome column(s) are specifically excluded during prediction. I would try using
Here's a smaller reprex:
library(tidymodels)
tidymodels_prefer()
theme_set(theme_bw())
options(pillar.advice = FALSE, pillar.min_title_chars = Inf)
mtcar_wts <-
mtcars %>%
mutate(case_wts = hardhat::importance_weights(NA))
car_rec <-
recipe(mpg ~ wt + disp + gear, data = mtcar_wts) %>%
step_mutate(case_wts = hardhat::importance_weights(1 / mpg), role = "case_weights")
lm_fit <-
car_rec %>%
workflow(linear_reg()) %>%
add_case_weights(case_wts) %>%
fit(mtcar_wts)
lm_fit %>%
extract_fit_engine() %>%
coef()
#> (Intercept) wt disp gear
#> 34.32324535 -2.69696818 -0.02006457 -0.34745007
predict(lm_fit, mtcar_wts[1:3,])
#> Error in `dplyr::mutate()`:
#> ℹ In argument: `case_wts = hardhat::importance_weights(1/mpg)`.
#> Caused by error in `FUN()`:
#> ! non-numeric argument to binary operator
#> Backtrace:
#> ▆
#> 1. ├─stats::predict(lm_fit, mtcar_wts[1:3, ])
#> 2. ├─workflows:::predict.workflow(lm_fit, mtcar_wts[1:3, ])
#> 3. │ └─workflows:::forge_predictors(new_data, workflow) at workflows/R/predict.R:63:3
#> 4. │ ├─hardhat::forge(new_data, blueprint = mold$blueprint) at workflows/R/predict.R:70:3
#> 5. │ └─hardhat:::forge.data.frame(new_data, blueprint = mold$blueprint) at hardhat/R/forge.R:68:3
#> 6. │ ├─hardhat::run_forge(blueprint, new_data = new_data, outcomes = outcomes) at hardhat/R/forge.R:81:3
#> 7. │ └─hardhat:::run_forge.default_recipe_blueprint(...) at hardhat/R/forge.R:135:3
#> 8. │ └─hardhat:::forge_recipe_default_process(...) at hardhat/R/blueprint-recipe-default.R:350:3
#> 9. │ ├─recipes::bake(object = rec, new_data = new_data) at hardhat/R/blueprint-recipe-default.R:435:3
#> 10. │ └─recipes:::bake.recipe(object = rec, new_data = new_data)
#> 11. │ ├─recipes::bake(step, new_data = new_data)
#> 12. │ └─recipes:::bake.step_mutate(step, new_data = new_data)
#> 13. │ ├─dplyr::mutate(new_data, !!!object$inputs)
#> 14. │ └─dplyr:::mutate.data.frame(new_data, !!!object$inputs)
#> 15. │ └─dplyr:::mutate_cols(.data, dplyr_quosures(...), by)
#> 16. │ ├─base::withCallingHandlers(...)
#> 17. │ └─dplyr:::mutate_col(dots[[i]], data, mask, new_columns)
#> 18. │ └─mask$eval_all_mutate(quo)
#> 19. │ └─dplyr (local) eval()
#> 20. ├─hardhat::importance_weights(1/mpg)
#> 21. │ └─hardhat:::vec_cast_named(x, to = double(), x_arg = "x") at hardhat/R/case-weights.R:31:3
#> 22. │ └─vctrs::vec_cast(x, to, ..., call = call) at hardhat/R/util.R:245:3
#> 23. ├─base::Ops.data.frame(1, mpg) at hardhat/R/case-weights.R:31:3
#> 24. │ └─base::eval(f)
#> 25. │ └─base::eval(f)
#> 26. └─base::.handleSimpleError(...)
#> 27. └─dplyr (local) h(simpleError(msg, call))
#> 28. └─rlang::abort(message, class = error_class, parent = parent, call = error_call)
car_skip_rec <-
recipe(mpg ~ wt + disp + gear, data = mtcar_wts) %>%
step_mutate(case_wts = hardhat::importance_weights(1 / mpg),
role = "case_weights", skip = TRUE)
lm_skip_fit <-
car_skip_rec %>%
workflow(linear_reg()) %>%
add_case_weights(case_wts) %>%
fit(mtcar_wts)
lm_skip_fit %>%
extract_fit_engine() %>%
coef()
#> (Intercept) wt disp gear
#> 34.32324535 -2.69696818 -0.02006457 -0.34745007
predict(lm_skip_fit, mtcar_wts[1:3,])
#> # A tibble: 3 × 1
#> .pred
#> <dbl>
#> 1 22.7
#> 2 22.0
#> 3 24.5
Created on 2024-01-31 with reprex v2.0.2
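The failure in the first half of the reprex can be sketched with dplyr alone. This is a minimal illustration, not the workflows internals: `new_data` is a hypothetical stand-in for what the recipe step sees at prediction time once the outcome column `mpg` has been dropped. In the backtrace above, `base::Ops.data.frame(1, mpg)` suggests that `mpg` then resolved to some data frame outside the data, possibly the `ggplot2::mpg` dataset attached by tidymodels; that reading is a guess based on the backtrace only.

```r
library(dplyr)

# Hypothetical stand-in for prediction-time data: predictors only,
# with the outcome column `mpg` removed.
new_data <- mtcars[1:3, c("wt", "disp", "gear")]

# The expression used in step_mutate() needs `mpg`, so evaluating it on
# data without that column raises an error.
res <- try(mutate(new_data, case_wts = 1 / mpg), silent = TRUE)
failed_without_outcome <- inherits(res, "try-error")

# With the outcome present, as at fit time, the same expression evaluates.
train_data <- mutate(mtcars, case_wts = 1 / mpg)

failed_without_outcome
head(train_data$case_wts, 3)
```

With `skip = TRUE`, as in the second half of the reprex, the step is not run when predicting, so the missing outcome column is never referenced.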
I've added importance weights to a logistic regression using hardhat::importance_weights. However, when I try to generate partial dependence plots for the regression using DALEXtra::model_profile, it returns an error.
Typically when we generate predictions from a fitted model we don't need to use weights, so I'm not sure why this is returning an error. Unless it's simply that DALEXtra doesn't know how to deal with a column formatted as <importance_weights>?
Here's a reprex with a dummy dataset. I'm trying to extract partial dependence profiles for each fold of the dataset to visually validate the model fit, as discussed here #tidymodels/planning/issues/26. The issue may be complicated as I'm calculating the weights on the fly, depending on the number of points assigned to each fold, as discussed here #240.
Created on 2023-05-30 with reprex v2.0.2
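As a small illustration of the point that weights are not needed at prediction time, here is a base R sketch. It is an assumption-laden stand-in: it uses a weighted linear model on `mtcars` with the same `1 / mpg` weights as the smaller reprex in this thread, not the logistic regression or the DALEXtra call from the original report.

```r
# Weighted least squares in base R; the weights are only used when fitting.
wt_fit <- lm(mpg ~ wt + disp + gear, data = mtcars, weights = 1 / mpg)

# Prediction data holds predictors only: no weights and no outcome column.
new_data <- mtcars[1:3, c("wt", "disp", "gear")]
preds <- predict(wt_fit, newdata = new_data)

preds
```

The fitted coefficients already carry the effect of the weights, so `predict()` only needs the predictor columns.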