Skip to content

Commit

Permalink
Fix contrasts issue #283 and minor other changes (#284)
Browse files Browse the repository at this point in the history
  • Loading branch information
fweber144 authored Mar 3, 2022
1 parent 343ee37 commit 4b213ad
Show file tree
Hide file tree
Showing 19 changed files with 100 additions and 84 deletions.
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
* Argument `fit` of `init_refmodel()`'s argument `proj_predfun` was renamed to `fits`. This is a non-breaking change since all calls to `proj_predfun` in **projpred** have that argument unnamed. However, this cannot be guaranteed in the future, so we strongly encourage users with a custom `proj_predfun` to rename argument `fit` to `fits`. (GitHub: #263)
* `init_refmodel()` has gained argument `cvrefbuilder` which may be a custom function for constructing the K reference models in a K-fold CV. (GitHub: #271)
* Allow arguments to be passed from `project()`, `varsel()`, and `cv_varsel()` to the divergence minimizer. (GitHub: #278)
* In `init_refmodel()`, any `contrasts` attributes of the dataset's columns are silently removed. (GitHub: #284)

## Bug fixes

Expand Down
2 changes: 1 addition & 1 deletion R/cv_varsel.R
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@
#'
#' Vehtari, A., Gelman, A., and Gabry, J. (2017). Practical Bayesian model
#' evaluation using leave-one-out cross-validation and WAIC. *Statistics and
#' Computing*, **27**(5), 1413-1432. DOI: \doi{10.1007/s11222-016-9696-4}.
#' Computing*, **27**(5), 1413-1432. \doi{10.1007/s11222-016-9696-4}.
#'
#' Vehtari, A., Simpson, D., Gelman, A., Yao, Y., and Gabry, J. (2021). Pareto
#' smoothed importance sampling. *arXiv:1507.02646*. URL:
Expand Down
2 changes: 1 addition & 1 deletion R/data.R
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
#' group).}
#' }
#' @references Gelman, A. and Hill, J. (2006). *Data Analysis Using Regression
#' and Multilevel/Hierarchical Models*. Cambridge University Press. DOI:
#' and Multilevel/Hierarchical Models*. Cambridge University Press.
#' \doi{10.1017/CBO9780511790942}.
#' @source <http://www.stat.columbia.edu/~gelman/arm/examples/mesquite/mesquite.dat>
"mesquite"
101 changes: 52 additions & 49 deletions R/divergence_minimizers.R
Original file line number Diff line number Diff line change
Expand Up @@ -85,9 +85,12 @@ divmin <- function(formula, projpred_var, ...) {
}
}

# Use projpred's own implementation to fit non-multilevel non-additive
# submodels:
fit_glm_ridge_callback <- function(formula, data,
projpred_var = matrix(nrow = nrow(data)),
projpred_regul = 1e-4, ...) {
# Preparations:
fr <- model.frame(delete.intercept(formula), data = data)
contrasts_arg <- get_contrasts_arg_list(formula, data = data)
x <- model.matrix(fr, data = data, contrasts.arg = contrasts_arg)
Expand All @@ -98,10 +101,12 @@ fit_glm_ridge_callback <- function(formula, data,
names(dot_args),
methods::formalArgs(glm_ridge)
)]
# Call the submodel fitter:
fit <- do.call(glm_ridge, c(
list(x = x, y = y, lambda = projpred_regul, obsvar = projpred_var),
dot_args
))
# Post-processing:
rownames(fit$beta) <- colnames(x)
sub <- nlist(
alpha = fit$beta0,
Expand All @@ -118,37 +123,34 @@ fit_glm_ridge_callback <- function(formula, data,
# `projpred.glm_fitter`):
fit_glm_callback <- function(formula, family, projpred_var, projpred_regul,
...) {
tryCatch({
if (family$family == "gaussian" && family$link == "identity") {
# Exclude arguments from `...` which cannot be passed to stats::lm():
dot_args <- list(...)
dot_args <- dot_args[intersect(
names(dot_args),
union(methods::formalArgs(stats::lm),
union(methods::formalArgs(stats::lm.fit),
methods::formalArgs(stats::lm.wfit)))
)]
return(suppressMessages(suppressWarnings(do.call(stats::lm, c(
list(formula = formula),
dot_args
)))))
} else {
# Exclude arguments from `...` which cannot be passed to stats::glm():
dot_args <- list(...)
dot_args <- dot_args[intersect(
names(dot_args),
union(methods::formalArgs(stats::glm),
methods::formalArgs(stats::glm.control))
)]
return(suppressMessages(suppressWarnings(do.call(stats::glm, c(
list(formula = formula, family = family),
dot_args
)))))
}
}, error = function(e) {
# May be used to handle errors.
stop(e)
})
if (family$family == "gaussian" && family$link == "identity") {
# Exclude arguments from `...` which cannot be passed to stats::lm():
dot_args <- list(...)
dot_args <- dot_args[intersect(
names(dot_args),
c(methods::formalArgs(stats::lm),
methods::formalArgs(stats::lm.fit),
methods::formalArgs(stats::lm.wfit))
)]
# Call the submodel fitter:
return(suppressMessages(suppressWarnings(do.call(stats::lm, c(
list(formula = formula),
dot_args
)))))
} else {
# Exclude arguments from `...` which cannot be passed to stats::glm():
dot_args <- list(...)
dot_args <- dot_args[intersect(
names(dot_args),
c(methods::formalArgs(stats::glm),
methods::formalArgs(stats::glm.control))
)]
# Call the submodel fitter:
return(suppressMessages(suppressWarnings(do.call(stats::glm, c(
list(formula = formula, family = family),
dot_args
)))))
}
}

# Use package "mgcv" to fit additive non-multilevel submodels:
Expand All @@ -158,9 +160,10 @@ fit_gam_callback <- function(formula, ...) {
dot_args <- list(...)
dot_args <- dot_args[intersect(
names(dot_args),
union(methods::formalArgs(gam),
methods::formalArgs(mgcv::gam.fit))
c(methods::formalArgs(gam),
methods::formalArgs(mgcv::gam.fit))
)]
# Call the submodel fitter:
return(suppressMessages(suppressWarnings(do.call(gam, c(
list(formula = formula),
dot_args
Expand All @@ -176,10 +179,11 @@ fit_gamm_callback <- function(formula, projpred_formula_no_random,
dot_args <- list(...)
dot_args <- dot_args[intersect(
names(dot_args),
union(union(methods::formalArgs(gamm4),
methods::formalArgs(lme4::lFormula)),
methods::formalArgs(lme4::glFormula))
c(methods::formalArgs(gamm4),
methods::formalArgs(lme4::lFormula),
methods::formalArgs(lme4::glFormula))
)]
# Call the submodel fitter:
fit <- tryCatch({
suppressMessages(suppressWarnings(do.call(gamm4, c(
list(formula = projpred_formula_no_random, random = projpred_random,
Expand Down Expand Up @@ -211,9 +215,7 @@ fit_gamm_callback <- function(formula, projpred_formula_no_random,
return(fit)
}

# Use package "lme4" to fit submodels for multilevel reference models (with a
# fallback to "projpred"'s own implementation for fitting non-multilevel (and
# non-additive) submodels):
# Use package "lme4" to fit multilevel submodels:
fit_glmer_callback <- function(formula, family,
control = control_callback(family), ...) {
tryCatch({
Expand All @@ -224,6 +226,7 @@ fit_glmer_callback <- function(formula, family,
names(dot_args),
methods::formalArgs(lme4::lmer)
)]
# Call the submodel fitter:
return(suppressMessages(suppressWarnings(do.call(lme4::lmer, c(
list(formula = formula, control = control),
dot_args
Expand All @@ -235,6 +238,7 @@ fit_glmer_callback <- function(formula, family,
names(dot_args),
methods::formalArgs(lme4::glmer)
)]
# Call the submodel fitter:
return(suppressMessages(suppressWarnings(do.call(lme4::glmer, c(
list(formula = formula, family = family,
control = control),
Expand Down Expand Up @@ -413,20 +417,19 @@ check_conv <- function(fit) {
# Prediction functions for submodels --------------------------------------

subprd <- function(fits, newdata) {
return(do.call(cbind, lapply(fits, function(fit) {
# Only pass argument `allow.new.levels` to the predict() generic if the fit
# is multilevel:
has_grp <- inherits(fit, c("lmerMod", "glmerMod"))
has_add <- inherits(fit, c("gam", "gamm4"))
if (has_add && !is.null(newdata)) {
prd_list <- lapply(fits, function(fit) {
is_glmm <- inherits(fit, c("lmerMod", "glmerMod"))
is_gam_gamm <- inherits(fit, c("gam", "gamm4"))
if (is_gam_gamm && !is.null(newdata)) {
newdata <- cbind(`(Intercept)` = rep(1, NROW(newdata)), newdata)
}
if (!has_grp) {
return(predict(fit, newdata = newdata))
} else {
if (is_glmm) {
return(predict(fit, newdata = newdata, allow.new.levels = TRUE))
} else {
return(predict(fit, newdata = newdata))
}
})))
})
return(do.call(cbind, prd_list))
}

## FIXME: find a way that allows us to remove this
Expand Down
1 change: 1 addition & 0 deletions R/methods.R
Original file line number Diff line number Diff line change
Expand Up @@ -626,6 +626,7 @@ print.vselsummary <- function(x, digits = 1, ...) {
cat(paste0("Suggested Projection Size: ", x$suggested_size, "\n"))
cat("\n")
cat("Selection Summary:\n")
where <- "tidyselect" %:::% "where"
print(
x$selection %>% dplyr::mutate(dplyr::across(
where(is.numeric),
Expand Down
10 changes: 5 additions & 5 deletions R/misc.R
Original file line number Diff line number Diff line change
Expand Up @@ -430,18 +430,18 @@ deparse_combine <- function(x, max_char = NULL) {
#' @export
magrittr::`%>%`

# `R CMD check` throws a note when using <package>:::<function>() (for accessing
# <function> which is not exported by its <package>). Of course, usage of
# non-exported functions should be avoided, but sometimes there's no way around
# that. Thus, with the following helper operator, it is possible to redefine
# such functions here in projpred:
`%:::%` <- function(pkg, fun) {
# Note: `utils::getFromNamespace(fun, pkg)` could probably be used, too (but
# its documentation is unclear about the inheritance from parent
# environments).
get(fun, envir = asNamespace(pkg), inherits = FALSE)
}

# Function where() is not exported by package tidyselect, so redefine it here to
# avoid a note in R CMD check which would occur for usage of
# tidyselect:::where():
where <- "tidyselect" %:::% "where"

# Helper function to combine separate `list`s into a single `list`:
rbind2list <- function(x) {
as.list(do.call(rbind, lapply(x, as.data.frame)))
Expand Down
6 changes: 3 additions & 3 deletions R/projpred-package.R
Original file line number Diff line number Diff line change
Expand Up @@ -92,15 +92,15 @@
#'
#' Dupuis, J. A. and Robert, C. P. (2003). Variable selection in qualitative
#' models via an entropic explanatory power. *Journal of Statistical Planning
#' and Inference*, **111**(1-2):77–94. DOI: \doi{10.1016/S0378-3758(02)00286-0}.
#' and Inference*, **111**(1-2):77–94. \doi{10.1016/S0378-3758(02)00286-0}.
#'
#' Piironen, J. and Vehtari, A. (2017). Comparison of Bayesian predictive
#' methods for model selection. *Statistics and Computing*, **27**(3):711-735.
#' DOI: \doi{10.1007/s11222-016-9649-y}.
#' \doi{10.1007/s11222-016-9649-y}.
#'
#' Piironen, J., Paasiniemi, M., and Vehtari, A. (2020). Projective inference in
#' high-dimensional problems: Prediction and feature selection. *Electronic
#' Journal of Statistics*, **14**(1):2155-2197. DOI: \doi{10.1214/20-EJS1711}.
#' Journal of Statistics*, **14**(1):2155-2197. \doi{10.1214/20-EJS1711}.
#'
#' Catalina, A., Bürkner, P.-C., and Vehtari, A. (2020). Projection predictive
#' inference for generalized linear and additive multilevel models.
Expand Down
12 changes: 11 additions & 1 deletion R/refmodel.R
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@
#' additionally to the properties required for [init_refmodel()]. For
#' non-default methods of [get_refmodel()], an object of the corresponding
#' class.
#' @param data Data used for fitting the reference model.
#' @param data Data used for fitting the reference model. Any `contrasts`
#' attributes of the dataset's columns are silently removed.
#' @param formula Reference model's formula. For general information on formulas
#' in \R, see [`formula`]. For multilevel formulas, see also package
#' \pkg{lme4} (in particular, functions [lme4::lmer()] and [lme4::glmer()]).
Expand Down Expand Up @@ -664,6 +665,15 @@ init_refmodel <- function(object, data, formula, family, ref_predfun = NULL,
offset <- rep(0, NROW(y))
}

# For avoiding the warning "contrasts dropped from factor <factor_name>" when
# predicting for each projected draw, e.g., for submodels fit with lm()/glm():
has_contr <- sapply(data, function(data_col) {
!is.null(attr(data_col, "contrasts"))
})
for (idx_col in which(has_contr)) {
attr(data[[idx_col]], "contrasts") <- NULL
}

# Functions ---------------------------------------------------------------

if (proper_model) {
Expand Down
2 changes: 1 addition & 1 deletion man/cv_varsel.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion man/mesquite.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 3 additions & 3 deletions man/projpred-package.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion man/refmodel-init-get.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions tests/testthat/test_as_matrix.R
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ test_that("as.matrix.projection() works", {
print(tstsetup)
print(rlang::hash(m)) # cat(m)
})
options(width = width_orig$width)
options(width_orig)
if (testthat_ed_max2) local_edition(2)
}
}
Expand Down Expand Up @@ -276,6 +276,6 @@ if (run_snaps) {
}
})

options(width = width_orig$width)
options(width_orig)
if (testthat_ed_max2) local_edition(2)
}
8 changes: 4 additions & 4 deletions tests/testthat/test_datafit.R
Original file line number Diff line number Diff line change
Expand Up @@ -435,7 +435,7 @@ test_that(paste(
print(tstsetup)
print(rlang::hash(pl_with_args))
})
options(width = width_orig$width)
options(width_orig)
if (testthat_ed_max2) local_edition(2)
}
}
Expand Down Expand Up @@ -476,7 +476,7 @@ test_that(paste(
print(tstsetup)
print(rlang::hash(pp_with_args))
})
options(width = width_orig$width)
options(width_orig)
if (testthat_ed_max2) local_edition(2)
}
}
Expand Down Expand Up @@ -538,7 +538,7 @@ test_that(paste(
print(tstsetup)
print(smmry, digits = 6)
})
options(width = width_orig$width)
options(width_orig)
if (testthat_ed_max2) local_edition(2)
}
}
Expand Down Expand Up @@ -577,7 +577,7 @@ test_that(paste(
print(tstsetup)
print(smmry, digits = 6)
})
options(width = width_orig$width)
options(width_orig)
if (testthat_ed_max2) local_edition(2)
}
}
Expand Down
4 changes: 2 additions & 2 deletions tests/testthat/test_methods_vsel.R
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ test_that("`x` of class \"vselsummary\" (based on varsel()) works", {
print(tstsetup)
print(smmrys_vs[[tstsetup]], digits = 6)
})
options(width = width_orig$width)
options(width_orig)
if (testthat_ed_max2) local_edition(2)
}
}
Expand All @@ -141,7 +141,7 @@ test_that("`x` of class \"vselsummary\" (based on cv_varsel()) works", {
print(tstsetup)
print(smmrys_cvvs[[tstsetup]], digits = 6)
})
options(width = width_orig$width)
options(width_orig)
if (testthat_ed_max2) local_edition(2)
}
}
Expand Down
Loading

0 comments on commit 4b213ad

Please sign in to comment.