Fix population scaling with other_keys (+ allow single-quantile-level predictions) #418

Closed
2 changes: 1 addition & 1 deletion DESCRIPTION
@@ -1,6 +1,6 @@
Package: epipredict
Title: Basic epidemiology forecasting methods
Version: 0.1.2
Version: 0.1.3
Authors@R: c(
person("Daniel J.", "McDonald", , "[email protected]", role = c("aut", "cre")),
person("Ryan", "Tibshirani", , "[email protected]", role = "aut"),
2 changes: 2 additions & 0 deletions NAMESPACE
@@ -252,6 +252,7 @@ importFrom(dplyr,filter)
importFrom(dplyr,full_join)
importFrom(dplyr,group_by)
importFrom(dplyr,group_by_at)
importFrom(dplyr,inner_join)
importFrom(dplyr,join_by)
importFrom(dplyr,left_join)
importFrom(dplyr,mutate)
@@ -283,6 +284,7 @@ importFrom(hardhat,extract_recipe)
importFrom(hardhat,refresh_blueprint)
importFrom(hardhat,run_mold)
importFrom(magrittr,"%>%")
importFrom(magrittr,extract2)
importFrom(recipes,bake)
importFrom(recipes,detect_step)
importFrom(recipes,prep)
5 changes: 2 additions & 3 deletions R/autoplot.R
@@ -127,11 +127,10 @@ autoplot.epi_workflow <- function(
if (!is.null(shift)) {
edf <- mutate(edf, time_value = time_value + shift)
}
extra_keys <- setdiff(key_colnames(object), c("geo_value", "time_value"))
if (length(extra_keys) == 0L) extra_keys <- NULL
other_keys <- setdiff(key_colnames(object), c("geo_value", "time_value"))
edf <- as_epi_df(edf,
as_of = object$fit$meta$as_of,
other_keys = extra_keys %||% character()
other_keys = other_keys
)
if (is.null(predictions)) {
return(autoplot(
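For context, a minimal standalone sketch of the simplified pattern above (toy data, not part of the PR): `setdiff()` against the standard keys yields `character(0)` when there are no other keys, and `as_epi_df()` accepts that directly, so the old `%||% character()` fallback is unnecessary.

```r
# Illustrative sketch only (toy tibble; not from this PR).
library(dplyr)
library(epiprocess)

snapshot <- tibble::tibble(
  geo_value = c("ca", "ny"),
  time_value = as.Date("2021-01-01"),
  value = c(1, 2)
)

# With no extra keys, setdiff() returns character(0) ...
other_keys <- setdiff(c("geo_value", "time_value"), c("geo_value", "time_value"))

# ... which as_epi_df() accepts as-is, so no NULL special case is needed:
edf <- as_epi_df(snapshot, as_of = as.Date("2021-01-02"), other_keys = other_keys)
```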
2 changes: 2 additions & 0 deletions R/epipredict-package.R
@@ -7,7 +7,9 @@
#' @importFrom cli cli_abort cli_warn
#' @importFrom dplyr arrange across all_of any_of bind_cols bind_rows group_by
#' @importFrom dplyr full_join relocate summarise everything
#' @importFrom dplyr inner_join
#' @importFrom dplyr summarize filter mutate select left_join rename ungroup
#' @importFrom magrittr extract2
#' @importFrom rlang := !! %||% as_function global_env set_names !!! caller_arg
#' @importFrom rlang is_logical is_true inject enquo enquos expr sym arg_match
#' @importFrom stats poly predict lm residuals quantile
10 changes: 6 additions & 4 deletions R/key_colnames.R
@@ -1,20 +1,22 @@
#' @export
key_colnames.recipe <- function(x, ...) {
key_colnames.recipe <- function(x, ..., exclude = character()) {
geo_key <- x$var_info$variable[x$var_info$role %in% "geo_value"]
time_key <- x$var_info$variable[x$var_info$role %in% "time_value"]
keys <- x$var_info$variable[x$var_info$role %in% "key"]
c(geo_key, keys, time_key) %||% character(0L)
full_key <- c(geo_key, keys, time_key) %||% character(0L)
full_key[!full_key %in% exclude]
}

#' @export
key_colnames.epi_workflow <- function(x, ...) {
key_colnames.epi_workflow <- function(x, ..., exclude = character()) {
# safer to look at the mold than the preprocessor
mold <- hardhat::extract_mold(x)
molded_names <- names(mold$extras$roles)
geo_key <- names(mold$extras$roles[molded_names %in% "geo_value"]$geo_value)
time_key <- names(mold$extras$roles[molded_names %in% "time_value"]$time_value)
keys <- names(mold$extras$roles[molded_names %in% "key"]$key)
c(geo_key, keys, time_key) %||% character(0L)
full_key <- c(geo_key, keys, time_key) %||% character(0L)
full_key[!full_key %in% exclude]
}

kill_time_value <- function(v) {
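For context, a usage sketch of the new `exclude` argument (dataset and model are illustrative placeholders; the same pattern appears in the tests at the bottom of this diff):

```r
# Illustrative sketch (dataset/model choice is a placeholder, not from this PR).
library(epipredict)

wf <- epi_workflow(
  epi_recipe(covid_case_death_rates) %>%
    step_epi_lag(death_rate, lag = c(0, 7)) %>%
    step_epi_ahead(death_rate, ahead = 7),
  parsnip::linear_reg()
) %>%
  fit(data = covid_case_death_rates)

key_colnames(wf)                          # c("geo_value", "time_value")
key_colnames(wf, exclude = "time_value")  # "geo_value"
```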
23 changes: 22 additions & 1 deletion R/layer_population_scaling.R
@@ -135,6 +135,25 @@ slather.layer_population_scaling <-
)
rlang::check_dots_empty()

if (is.null(object$by)) {
# Assume `layer_predict` has calculated the prediction keys and other
# layers don't change the prediction key colnames:
prediction_key_colnames <- names(components$keys)
lhs_potential_keys <- prediction_key_colnames
rhs_potential_keys <- colnames(select(object$df, !object$df_pop_col))
object$by <- intersect(lhs_potential_keys, rhs_potential_keys)
suggested_min_keys <- kill_time_value(lhs_potential_keys)
if (!all(suggested_min_keys %in% object$by)) {
cli_warn(c(
"Couldn't find {setdiff(suggested_min_keys, object$by)} in population `df`",
"i" = "Defaulting to join by {object$by}",
">" = "Double-check whether column names on the population `df` match those expected in your predictions",
">" = "Consider using population data with breakdowns by {suggested_min_keys}",
">" = "Manually specify `by =` to silence"
), class = "epipredict__layer_population_scaling__default_by_missing_suggested_keys")
}
}

object$by <- object$by %||% intersect(
epi_keys_only(components$predictions),
colnames(select(object$df, !object$df_pop_col))
@@ -152,10 +171,12 @@
suffix <- ifelse(object$create_new, object$suffix, "")
col_to_remove <- setdiff(colnames(object$df), colnames(components$predictions))

components$predictions <- left_join(
components$predictions <- inner_join(
components$predictions,
object$df,
by = object$by,
relationship = "many-to-one",
unmatched = c("error", "drop"),
suffix = c("", ".df")
) %>%
mutate(across(
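For context, a hedged sketch of how the new default plays out in a frosting pipeline (the population table and column names are illustrative): with `by = NULL`, the layer now joins on the intersection of the prediction key columns and the non-population columns of `df`, warns when a geo/other key has no counterpart in `df`, and the switch to `inner_join()` with `unmatched = c("error", "drop")` turns predictions without a matching population row into an error instead of silent `NA`s.

```r
# Illustrative sketch (toy population table; not from this PR).
library(epipredict)

pop <- tibble::tibble(
  geo_value = c("ca", "ny"),
  pop = c(39.5e6, 19.5e6)
)

f <- frosting() %>%
  layer_predict() %>%
  layer_population_scaling(
    .pred,
    df = pop,
    df_pop_col = "pop",
    # `by` is omitted on purpose: with this PR it defaults to the intersection
    # of the prediction key columns and the non-population columns of `pop`,
    # warning if a geo/other key has no counterpart in `pop`.
    rate_rescaling = 1e5 # report rates per 100k
  )

# Then, roughly: wf <- add_frosting(fitted_epi_workflow, f); predict(wf, test_data)
```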
2 changes: 1 addition & 1 deletion R/make_quantile_reg.R
@@ -112,7 +112,7 @@ make_quantile_reg <- function() {

# can't make a method because object is second
out <- switch(type,
rq = dist_quantiles(unname(as.list(x)), object$quantile_levels), # one quantile
rq = dist_quantiles(unname(as.list(x)), object$tau), # one quantile
rqs = {
x <- lapply(vctrs::vec_chop(x), function(x) sort(drop(x)))
dist_quantiles(x, list(object$tau))
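For context, the one-line fix above is what enables the single-quantile-level case in the PR title: a single-level fit from quantreg is an "rq" object that stores its level in `$tau`, which is what the postprocessor now reads. A hedged sketch of the case that this unblocks (`quantile_levels` is assumed to be the spec argument name in the current epipredict API):

```r
# Illustrative sketch (toy data; spec argument name assumed, not from this PR).
library(epipredict)

toy <- data.frame(x = rnorm(100), y = rnorm(100))

spec <- quantile_reg(quantile_levels = 0.5) # one level -> quantreg returns an "rq" fit

fit <- parsnip::fit(spec, y ~ x, data = toy)

# Predictions now carry the single 0.5 level, taken from the fit's `tau`:
predict(fit, new_data = data.frame(x = c(-1, 0, 1)))
```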
61 changes: 53 additions & 8 deletions R/step_population_scaling.R
@@ -89,13 +89,25 @@ step_population_scaling <-
suffix = "_scaled",
skip = FALSE,
id = rand_id("population_scaling")) {
arg_is_scalar(role, df_pop_col, rate_rescaling, create_new, suffix, id)
arg_is_lgl(create_new, skip)
arg_is_chr(df_pop_col, suffix, id)
if (rlang::dots_n(...) == 0L) {
cli_abort(c(
"`...` must not be empty.",
">" = "Please provide one or more tidyselect expressions in `...`
specifying the columns to which scaling should be applied.",
">" = "If you really want to list `step_population_scaling` in your
recipe but not have it do anything, you can use a tidyselection
that selects zero variables, such as `c()`."
))
}
arg_is_scalar(role, df_pop_col, rate_rescaling, create_new, suffix, skip, id)
arg_is_chr(role, df_pop_col, suffix, id)
hardhat::validate_column_names(df, df_pop_col)
arg_is_chr(by, allow_null = TRUE)
arg_is_numeric(rate_rescaling)
if (rate_rescaling <= 0) {
cli_abort("`rate_rescaling` must be a positive number.")
}
arg_is_lgl(create_new, skip)

recipes::add_step(
recipe,
@@ -138,6 +150,41 @@ step_population_scaling_new <-

#' @export
prep.step_population_scaling <- function(x, training, info = NULL, ...) {
if (is.null(x$by)) {
rhs_potential_keys <- setdiff(colnames(x$df), x$df_pop_col)
lhs_potential_keys <- info %>%
filter(role %in% c("geo_value", "key", "time_value")) %>%
extract2("variable") %>%
unique() # in case of weird var with multiple of above roles
if (length(lhs_potential_keys) == 0L) {
# We're working with a recipe and tibble, and *_role hasn't set up any of
# the above roles. Let's say any column could actually act as a key, and
# lean on `intersect` below to make this something reasonable.
lhs_potential_keys <- names(training)
}
suggested_min_keys <- info %>%
filter(role %in% c("geo_value", "key")) %>%
extract2("variable") %>%
unique()
# (0 suggested keys if we weren't given any epikeytime var info.)
x$by <- intersect(lhs_potential_keys, rhs_potential_keys)
if (length(x$by) == 0L) {
cli_abort(c(
"Couldn't guess a default for `by`",
">" = "Please rename columns in your population data to match those in your training data,
or manually specify `by =` in `step_population_scaling()`."
), class = "epipredict__step_population_scaling__default_by_no_intersection")
}
if (!all(suggested_min_keys %in% x$by)) {
cli_warn(c(
"Couldn't find {setdiff(suggested_min_keys, x$by)} in population `df`.",
"i" = "Defaulting to join by {x$by}.",
">" = "Double-check whether column names on the population `df` match those for your time series.",
">" = "Consider using population data with breakdowns by {suggested_min_keys}.",
">" = "Manually specify `by =` to silence."
), class = "epipredict__step_population_scaling__default_by_missing_suggested_keys")
}
}
step_population_scaling_new(
terms = x$terms,
role = x$role,
@@ -156,10 +203,6 @@ prep.step_population_scaling <- function(x, training, info = NULL, ...) {

#' @export
bake.step_population_scaling <- function(object, new_data, ...) {
object$by <- object$by %||% intersect(
epi_keys_only(new_data),
colnames(select(object$df, !object$df_pop_col))
)
joinby <- list(x = names(object$by) %||% object$by, y = object$by)
hardhat::validate_column_names(new_data, joinby$x)
hardhat::validate_column_names(object$df, joinby$y)
@@ -177,7 +220,9 @@ bake.step_population_scaling <- function(object, new_data, ...) {
suffix <- ifelse(object$create_new, object$suffix, "")
col_to_remove <- setdiff(colnames(object$df), colnames(new_data))

left_join(new_data, object$df, by = object$by, suffix = c("", ".df")) %>%
inner_join(new_data, object$df,
by = object$by, relationship = "many-to-one", unmatched = c("error", "drop"),
suffix = c("", ".df")) %>%
mutate(
across(
all_of(object$columns),
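For context, a hedged end-to-end sketch of the preprocessing-side behavior added above (toy data and column names are illustrative): `prep()` now infers `by` from the recipe's geo/other/time key roles, warning when a geo or other key has no counterpart in `df`, and `bake()` now uses a strict `inner_join()` so rows without a matching population entry raise an error instead of picking up `NA`s.

```r
# Illustrative sketch (toy data; column names are placeholders, not from this PR).
library(dplyr)
library(recipes)
library(epipredict)

counts <- tibble::tibble(
  geo_value = rep(c("ca", "ny"), each = 3),
  age_group = "all", # an "other" key
  time_value = rep(as.Date("2021-01-01") + 0:2, times = 2),
  cases = c(100, 120, 140, 80, 90, 100)
) %>%
  epiprocess::as_epi_df(other_keys = "age_group", as_of = as.Date("2021-01-10"))

pop <- tibble::tibble(
  geo_value = c("ca", "ny"),
  age_group = "all",
  pop = c(39.5e6, 19.5e6)
)

rec <- epi_recipe(counts) %>%
  # `by` is omitted: with this PR, prep() defaults it to the shared key
  # columns, here c("geo_value", "age_group"), and would warn if a geo/other
  # key such as age_group were missing from `pop`.
  step_population_scaling(cases,
    df = pop, df_pop_col = "pop",
    rate_rescaling = 1e5, suffix = "_per_100k"
  )

bake(prep(rec, counts), new_data = NULL) %>% head()
```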
8 changes: 4 additions & 4 deletions R/utils-misc.R
@@ -40,7 +40,7 @@ grab_forged_keys <- function(forged, workflow, new_data) {
# 2. these are the keys in the training data
old_keys <- key_colnames(workflow)
# 3. these are the keys in the test data as input
new_df_keys <- key_colnames(new_data, extra_keys = setdiff(new_keys, c("geo_value", "time_value")))
new_df_keys <- key_colnames(new_data, other_keys = setdiff(new_keys, c("geo_value", "time_value")))
if (!(setequal(old_keys, new_df_keys) && setequal(new_keys, new_df_keys))) {
cli_warn(paste(
"Not all epi keys that were present in the training data are available",
@@ -49,10 +49,10 @@
}
if (is_epi_df(new_data)) {
meta <- attr(new_data, "metadata")
extras <- as_epi_df(extras, as_of = meta$as_of, other_keys = meta$other_keys %||% character())
extras <- as_epi_df(extras, as_of = meta$as_of, other_keys = meta$other_keys)
} else if (all(c("geo_value", "time_value") %in% new_keys)) {
if (length(new_keys) > 2) other_keys <- new_keys[!new_keys %in% c("geo_value", "time_value")]
extras <- as_epi_df(extras, other_keys = other_keys %||% character())
other_keys <- new_keys[!new_keys %in% c("geo_value", "time_value")]
extras <- as_epi_df(extras, other_keys = other_keys)
}
extras
}
6 changes: 6 additions & 0 deletions tests/testthat/test-key_colnames.R
@@ -17,6 +17,9 @@ test_that("key_colnames extracts time_value and geo_value, but not raw", {
fit(data = covid_case_death_rates)

expect_identical(key_colnames(my_workflow), c("geo_value", "time_value"))

# `exclude =` works:
expect_identical(key_colnames(my_workflow, exclude = "geo_value"), c("time_value"))
})

test_that("key_colnames extracts additional keys when they are present", {
@@ -49,4 +52,7 @@ test_that("key_colnames extracts additional keys when they are present", {

# order of the additional keys may be different
expect_equal(key_colnames(my_workflow), c("geo_value", "state", "pol", "time_value"))

# `exclude =` works:
expect_equal(key_colnames(my_workflow, exclude = c("time_value", "pol")), c("geo_value", "state"))
})