
importance weights not compatible with DALEXtra::model_profile #242

Closed
jamesgrecian opened this issue May 30, 2023 · 2 comments

@jamesgrecian

I've added importance weights to a logistic regression using hardhat::importance_weights(). However, when I try to generate partial dependence plots for the regression using DALEXtra::model_profile(), it returns an error.

Typically we don't need case weights when generating predictions from a fitted model, so I'm not sure why this errors. Perhaps it's simply that DALEXtra doesn't know how to handle a column formatted as <importance_weights>?

Here's a reprex with a dummy dataset. I'm trying to extract partial dependence profiles for each fold of the dataset so I can visually validate the model fit, as discussed in tidymodels/planning#26. The issue may be complicated by the fact that I'm calculating the weights on the fly, based on the number of points assigned to each fold, as discussed in #240.

set.seed(1107)

# packages
library(sf)
#> Linking to GEOS 3.11.0, GDAL 3.5.3, PROJ 9.1.0; sf_use_s2() is TRUE
library(tidymodels)
library(spatialsample)
library(DALEXtra)
#> Loading required package: DALEX
#> Welcome to DALEX (version: 2.4.3).
#> Find examples and detailed introduction at: http://ema.drwhy.ai/
#> Additional features will be available after installation of: ggpubr.
#> Use 'install_dependencies()' to get all suggested dependencies
#> 
#> Attaching package: 'DALEX'
#> The following object is masked from 'package:dplyr':
#> 
#>     explain
#> Anaconda not found on your computer. Conda related functionality such as create_env.R and condaenv and yml parameters from explain_scikitlearn will not be available

## Data prep:
# pak::pkg_install("Nowosad/spDataLarge")
data("lsl", "study_mask", package = "spDataLarge")
ta <- terra::rast(system.file("raster/ta.tif", package = "spDataLarge"))
lsl <- lsl |> 
  st_as_sf(coords = c("x", "y"), crs = "EPSG:32717")

# convert to 0, 1 as is typical in species distribution modelling
lsl <- lsl |> 
  mutate(lslpts = factor(as.numeric(lslpts)-1)) |>
  # Creating a dummy case weights column, to get past initial verification by recipe
  mutate(cwts = hardhat::importance_weights(NA))

# set up case weights as a recipe step
lsl_recipe <- recipes::recipe(
  lslpts ~ slope + cplan + cprof + elev + log10_carea, 
  data = sf::st_drop_geometry(lsl)
) |> 
  recipes::step_mutate(
    cwts = hardhat::importance_weights(
      ifelse(lslpts == 1, 1, sum(lslpts == 1) / sum(lslpts == 0))
    ),
    # Need to set the "case_weights" role explicitly:
    role = "case_weights"
  )

# split into folds
lsl_folds <- spatial_block_cv(lsl, method = "random", v = 10)

# try GLM
glm_model <- logistic_reg() |> 
  set_engine("glm") |> 
  set_mode("classification")

# Using weights instead: no add_formula, because the formula is in our recipe
glm_wflow_wts <- workflow(preprocessor = lsl_recipe) |> 
  add_model(glm_model) |> 
  add_case_weights(cwts)

# fit model to one fold of the data
glm_fold_fit <- glm_wflow_wts |> fit(lsl_folds$splits[[1]] |> analysis())

# generate partial dependence profile for model
# ideally want to generate profile for each fold to verify model fit
glm_explainer <- explain_tidymodels(glm_fold_fit,
                                    data = lsl_folds$splits[[1]] |> 
                                      analysis() |> 
                                      st_drop_geometry() |> 
                                      dplyr::select(slope, cplan, cprof, elev, log10_carea),
                                    y = lsl_folds$splits[[1]] |> 
                                      analysis() |> 
                                      st_drop_geometry() |> 
                                      pull(lslpts)) |>
  model_profile(N = 100, type = "partial")
#> Preparation of a new explainer is initiated
#>   -> model label       :  workflow  (  default  )
#>   -> data              :  308  rows  5  cols 
#>   -> target variable   :  308  values 
#>   -> predict function  :  yhat.workflow  will be used (  default  )
#>   -> predicted values  :  No value for predict function target column. (  default  )
#>   -> model_info        :  package tidymodels , ver. 1.0.0 , task classification (  default  ) 
#>   -> model_info        :  Model info detected classification task but 'y' is a factor .  (  WARNING  )
#>   -> model_info        :  By deafult classification tasks supports only numercical 'y' parameter. 
#>   -> model_info        :  Consider changing to numerical vector with 0 and 1 values.
#>   -> model_info        :  Otherwise I will not be able to calculate residuals or loss function.
#>   -> predicted values  :  the predict_function returns an error when executed (  WARNING  ) 
#>   -> residual function :  difference between y and yhat (  default  )
#>   -> residuals         :  the residual_function returns an error when executed (  WARNING  ) 
#>   A new explainer has been created!
#> Error in `dplyr::mutate()`:
#> ℹ In argument: `cwts = hardhat::importance_weights(...)`.
#> Caused by error in `ifelse()`:
#> ! object 'lslpts' not found
#> Backtrace:
#>      ▆
#>   1. ├─DALEX::model_profile(...)
#>   2. │ ├─ingredients::ceteris_paribus(...)
#>   3. │ └─ingredients:::ceteris_paribus.explainer(...)
#>   4. │   └─ingredients:::ceteris_paribus.default(...)
#>   5. │     ├─ingredients:::calculate_variable_profile(...)
#>   6. │     └─ingredients:::calculate_variable_profile.default(...)
#>   7. │       └─base::lapply(...)
#>   8. │         └─ingredients (local) FUN(X[[i]], ...)
#>   9. │           ├─DALEX (local) predict_function(model, new_data, ...)
#>  10. │           └─DALEXtra:::yhat.workflow(model, new_data, ...)
#>  11. │             ├─base::as.matrix(predict(X.model, newdata, type = "prob"))
#>  12. │             ├─stats::predict(X.model, newdata, type = "prob")
#>  13. │             └─workflows:::predict.workflow(X.model, newdata, type = "prob")
#>  14. │               └─workflows:::forge_predictors(new_data, workflow)
#>  15. │                 ├─hardhat::forge(new_data, blueprint = mold$blueprint)
#>  16. │                 └─hardhat:::forge.data.frame(new_data, blueprint = mold$blueprint)
#>  17. │                   ├─hardhat::run_forge(blueprint, new_data = new_data, outcomes = outcomes)
#>  18. │                   └─hardhat:::run_forge.default_recipe_blueprint(...)
#>  19. │                     └─hardhat:::forge_recipe_default_process(...)
#>  20. │                       ├─recipes::bake(object = rec, new_data = new_data)
#>  21. │                       └─recipes:::bake.recipe(object = rec, new_data = new_data)
#>  22. │                         ├─recipes::bake(step, new_data = new_data)
#>  23. │                         └─recipes:::bake.step_mutate(step, new_data = new_data)
#>  24. │                           ├─dplyr::mutate(new_data, !!!object$inputs)
#>  25. │                           └─dplyr:::mutate.data.frame(new_data, !!!object$inputs)
#>  26. │                             └─dplyr:::mutate_cols(.data, dplyr_quosures(...), by)
#>  27. │                               ├─base::withCallingHandlers(...)
#>  28. │                               └─dplyr:::mutate_col(dots[[i]], data, mask, new_columns)
#>  29. │                                 └─mask$eval_all_mutate(quo)
#>  30. │                                   └─dplyr (local) eval()
#>  31. ├─hardhat::importance_weights(...)
#>  32. │ └─hardhat:::vec_cast_named(x, to = double(), x_arg = "x")
#>  33. │   └─vctrs::vec_cast(x, to, ..., call = call)
#>  34. ├─base::ifelse(lslpts == 1, 1, sum(lslpts == 1)/sum(lslpts == 0))
#>  35. └─base::.handleSimpleError(...)
#>  36.   └─dplyr (local) h(simpleError(msg, call))
#>  37.     └─rlang::abort(message, class = error_class, parent = parent, call = error_call)

Created on 2023-05-30 with reprex v2.0.2

@jamesgrecian changed the title from "hardhat::importance_weights break DALEXtra::model_profile" to "importance weights not compatible with DALEXtra::model_profile" on May 30, 2023
@topepo (Member) commented Jan 31, 2024

Your formula lslpts ~ slope + cplan + cprof + elev + log10_carea has lslpts as the outcome, and step_mutate() is using it to construct the case weights.

tidymodels enforces the constraint that the outcome must not be used (in any way) when making predictions, even if that column happens to be available at prediction time. This is to eliminate information leakage, so the outcome column(s) are specifically excluded during prediction; that is why bake() cannot find lslpts when the workflow predicts on new data.

I would try using skip = TRUE so that the step does not execute outside of processing the training set.
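
For instance, applied to the recipe from the original reprex (a minimal sketch, untested here):

lsl_recipe <- recipes::recipe(
  lslpts ~ slope + cplan + cprof + elev + log10_carea,
  data = sf::st_drop_geometry(lsl)
) |>
  recipes::step_mutate(
    cwts = hardhat::importance_weights(
      ifelse(lslpts == 1, 1, sum(lslpts == 1) / sum(lslpts == 0))
    ),
    role = "case_weights",
    # skip = TRUE: the step runs when the recipe is prepped on the
    # training data, but is not applied when baking new data
    skip = TRUE
  )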

@topepo (Member) commented Jan 31, 2024

Here's a smaller reprex:

library(tidymodels)
tidymodels_prefer()
theme_set(theme_bw())
options(pillar.advice = FALSE, pillar.min_title_chars = Inf)
mtcar_wts <- 
  mtcars %>% 
  mutate(case_wts = hardhat::importance_weights(NA))

car_rec <- 
  recipe(mpg ~ wt + disp + gear, data = mtcar_wts) %>% 
  step_mutate(case_wts = hardhat::importance_weights(1 / mpg), role = "case_weights")

lm_fit <- 
  car_rec %>% 
  workflow(linear_reg()) %>% 
  add_case_weights(case_wts) %>% 
  fit(mtcar_wts)

lm_fit %>% 
  extract_fit_engine() %>% 
  coef()
#> (Intercept)          wt        disp        gear 
#> 34.32324535 -2.69696818 -0.02006457 -0.34745007


predict(lm_fit, mtcar_wts[1:3,])
#> Error in `dplyr::mutate()`:
#> ℹ In argument: `case_wts = hardhat::importance_weights(1/mpg)`.
#> Caused by error in `FUN()`:
#> ! non-numeric argument to binary operator
#> Backtrace:
#>      ▆
#>   1. ├─stats::predict(lm_fit, mtcar_wts[1:3, ])
#>   2. ├─workflows:::predict.workflow(lm_fit, mtcar_wts[1:3, ])
#>   3. │ └─workflows:::forge_predictors(new_data, workflow) at workflows/R/predict.R:63:3
#>   4. │   ├─hardhat::forge(new_data, blueprint = mold$blueprint) at workflows/R/predict.R:70:3
#>   5. │   └─hardhat:::forge.data.frame(new_data, blueprint = mold$blueprint) at hardhat/R/forge.R:68:3
#>   6. │     ├─hardhat::run_forge(blueprint, new_data = new_data, outcomes = outcomes) at hardhat/R/forge.R:81:3
#>   7. │     └─hardhat:::run_forge.default_recipe_blueprint(...) at hardhat/R/forge.R:135:3
#>   8. │       └─hardhat:::forge_recipe_default_process(...) at hardhat/R/blueprint-recipe-default.R:350:3
#>   9. │         ├─recipes::bake(object = rec, new_data = new_data) at hardhat/R/blueprint-recipe-default.R:435:3
#>  10. │         └─recipes:::bake.recipe(object = rec, new_data = new_data)
#>  11. │           ├─recipes::bake(step, new_data = new_data)
#>  12. │           └─recipes:::bake.step_mutate(step, new_data = new_data)
#>  13. │             ├─dplyr::mutate(new_data, !!!object$inputs)
#>  14. │             └─dplyr:::mutate.data.frame(new_data, !!!object$inputs)
#>  15. │               └─dplyr:::mutate_cols(.data, dplyr_quosures(...), by)
#>  16. │                 ├─base::withCallingHandlers(...)
#>  17. │                 └─dplyr:::mutate_col(dots[[i]], data, mask, new_columns)
#>  18. │                   └─mask$eval_all_mutate(quo)
#>  19. │                     └─dplyr (local) eval()
#>  20. ├─hardhat::importance_weights(1/mpg)
#>  21. │ └─hardhat:::vec_cast_named(x, to = double(), x_arg = "x") at hardhat/R/case-weights.R:31:3
#>  22. │   └─vctrs::vec_cast(x, to, ..., call = call) at hardhat/R/util.R:245:3
#>  23. ├─base::Ops.data.frame(1, mpg) at hardhat/R/case-weights.R:31:3
#>  24. │ └─base::eval(f)
#>  25. │   └─base::eval(f)
#>  26. └─base::.handleSimpleError(...)
#>  27.   └─dplyr (local) h(simpleError(msg, call))
#>  28.     └─rlang::abort(message, class = error_class, parent = parent, call = error_call)
car_skip_rec <- 
  recipe(mpg ~ wt + disp + gear, data = mtcar_wts) %>% 
  step_mutate(case_wts = hardhat::importance_weights(1 / mpg), 
              role = "case_weights", skip = TRUE)

lm_skip_fit <- 
  car_skip_rec %>% 
  workflow(linear_reg()) %>% 
  add_case_weights(case_wts) %>% 
  fit(mtcar_wts)

lm_skip_fit %>% 
  extract_fit_engine() %>% 
  coef()
#> (Intercept)          wt        disp        gear 
#> 34.32324535 -2.69696818 -0.02006457 -0.34745007

predict(lm_skip_fit, mtcar_wts[1:3,])
#> # A tibble: 3 × 1
#>   .pred
#>   <dbl>
#> 1  22.7
#> 2  22.0
#> 3  24.5

Created on 2024-01-31 with reprex v2.0.2
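
Note that (per the recipes documentation) steps with skip = TRUE are applied when prep() runs on the training data but ignored when bake() processes new data. Case weights only influence estimation during fit(), so the weights column is never needed at prediction time, which is why skipping the step is safe here.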
