Skip to content

Commit

Permalink
Fix layer_add_target date and clean up
Browse files Browse the repository at this point in the history
  • Loading branch information
rachlobay committed Jul 26, 2023
1 parent 106b2ad commit a42de5b
Show file tree
Hide file tree
Showing 6 changed files with 125 additions and 65 deletions.
4 changes: 2 additions & 2 deletions R/epi_recipe.R
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ epi_recipe.epi_df <-
term_info = var_info,
steps = NULL,
template = x[1,],
mtv = max(x$time_value), #%%
mtv = max(x$time_value),
levels = NULL,
retained = NA
)
Expand Down Expand Up @@ -375,7 +375,7 @@ prep.epi_recipe <- function(
} else {
x$template <- training[0, ]
}
x$mtv <- max(training$time_value) #%%
x$mtv <- max(training$time_value)
x$tr_info <- tr_data
x$levels <- lvls
x$orig_lvls <- orig_lvls
Expand Down
27 changes: 0 additions & 27 deletions R/frosting.R
Original file line number Diff line number Diff line change
Expand Up @@ -229,9 +229,6 @@ apply_frosting.default <- function(workflow, components, ...) {
apply_frosting.epi_workflow <-
function(workflow, components, new_data, ...) {

#%% wf1$post$meta <<- list(mtv = max(new_data$time_value)) #%% change wf1 and possibly <<-
# assign("workflow$post$meta", list(mtv = max(new_data$time_value)), envir = .GlobalEnv)

the_fit <- workflows::extract_fit_parsnip(workflow)

if (!has_postprocessor(workflow)) {
Expand Down Expand Up @@ -267,34 +264,10 @@ apply_frosting.epi_workflow <-
la <- layers[[l]]
components <- slather(la, components, workflow, new_data)
}
#%% mtv <- max(new_data$time_value)
#%% update_workflow_post(workflow, mtv)

return(components)
}

#%% #' @export
# update_workflow_post <- function(x, mtv) {
# substitute(x) <- "changed"
# #assign(deparse(substitute(x)), "changed", env=.GlobalEnv)
# #workflow$post$meta <- list(mtv = max(new_data$time_value))
# }

#%% change_workflow = function(x){
# assign(deparse(substitute(x)), "changed", env=.GlobalEnv)
#}

#%% add_meta_post <- function(workflow, new_data){
#
# workflow$post$meta <- list(mtv = max(new_data$time_value))
#
# workflow
#}

#%% changeMe = function(x){
# assign(deparse(substitute(x)), "changed", env=.GlobalEnv)
#}

#' @export
print.frosting <- function(x, form_width = 30, ...) {
cli::cli_div(
Expand Down
10 changes: 5 additions & 5 deletions R/layer_add_forecast_date.R
Original file line number Diff line number Diff line change
Expand Up @@ -83,19 +83,19 @@ layer_add_forecast_date_new <- function(forecast_date, id) {

#' @export
slather.layer_add_forecast_date <- function(object, components, workflow, new_data, ...) {
#%% wf <<- workflow
#%% comp <<- components

if (is.null(object$forecast_date)) {
max_time_value <- max(workflows::extract_preprocessor(workflow)$mtv, workflow$fit$meta$mtv, max(new_data$time_value)) #%% wf$post$meta$mtv) # workflow$fit$max_train_time #max(new_data$time_value)
max_time_value <- max(workflows::extract_preprocessor(workflow)$mtv,
workflow$fit$meta$mtv,
max(new_data$time_value))
object$forecast_date <- max_time_value
}
as_of_pre <- attributes(workflows::extract_preprocessor(workflow)$template)$metadata$as_of
as_of_fit <- workflow$fit$meta$as_of
as_of_post <- attributes(new_data)$metadata$as_of

as_of_date <- as.Date(max(as_of_pre, as_of_fit, as_of_post)) #%% as.Date(attributes(components$keys)$metadata$as_of)
as_of_date <- as.Date(max(as_of_pre, as_of_fit, as_of_post))

# It would be nice to say that forecast_date is >= the max of all of them.
if (object$forecast_date < as_of_date) {
cli_warn(
c("The forecast_date is less than the most ",
Expand Down
52 changes: 37 additions & 15 deletions R/layer_add_target_date.R
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,10 @@
#'
#' @param frosting a `frosting` postprocessor
#' @param target_date The target date to add as a column to the `epi_df`.
#' By default, this is the maximum `time_value` (from the data used in
#' pre-processing, fitting the model, and postprocessing) plus `ahead`,
#' where `ahead` has been specified in preprocessing
#' (most likely in `step_epi_ahead`). The user may override this with a
#' date of their own (that will usually be in the form "yyyy-mm-dd").
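#' By default, this is the forecast date plus `ahead` (from `step_epi_ahead`
#' in the `epi_recipe`) if there is a `layer_add_forecast_date` with a
#' specified forecast date in the `epi_workflow`. Otherwise, it is the
#' maximum `time_value` (from the data used in pre-processing, fitting the
#' model, and postprocessing) plus `ahead`. If there's no such layer, then
#' the user may specify their own target date with a date (of the form
#' "yyyy-mm-dd").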
#' @param id a random id string
#'
#' @return an updated `frosting` postprocessor
Expand All @@ -28,24 +27,35 @@
#' wf <- epi_workflow(r, parsnip::linear_reg()) %>% fit(jhu)
#' latest <- get_test_data(r, jhu)
#'
#' # Use ahead from preprocessing
#' # Use ahead + forecast date
#' f <- frosting() %>% layer_predict() %>%
#' layer_add_forecast_date(forecast_date = "2022-05-31") %>%
#' layer_add_target_date() %>%
#' layer_naomit(.pred)
#' wf1 <- wf %>% add_frosting(f)
#'
#' p <- predict(wf1, latest)
#' p
#'
#' # Override default behaviour by specifying own target date
#' f2 <- frosting() %>%
#' layer_predict() %>%
#' layer_add_target_date(target_date = "2022-01-08") %>%
#' # Use ahead + max time value from pre, fit, post,
#' # which gives the same result as including `layer_add_forecast_date()`
#' # without a specified forecast date
#' f <- frosting() %>% layer_predict() %>%
#' layer_add_target_date() %>%
#' layer_naomit(.pred)
#' wf2 <- wf %>% add_frosting(f2)
#'
#' p2 <- predict(wf2, latest)
#' p2
#'
#' # Specify own target date
#' f3 <- frosting() %>%
#' layer_predict() %>%
#' layer_add_target_date(target_date = "2022-01-08") %>%
#' layer_naomit(.pred)
#' wf3 <- wf %>% add_frosting(f3)
#'
#' p3 <- predict(wf3, latest)
#' p3
layer_add_target_date <-
function(frosting, target_date = NULL, id = rand_id("add_target_date")) {
target_date <- arg_to_date(target_date, allow_null = TRUE)
Expand All @@ -67,14 +77,26 @@ layer_add_target_date_new <- function(id = id, target_date = target_date) {
slather.layer_add_target_date <- function(object, components, workflow, new_data, ...) {

the_recipe <- workflows::extract_recipe(workflow)
the_frosting <- extract_frosting(workflow)

if (detect_layer(the_frosting, "layer_add_forecast_date") &&
!is.null(extract_argument(the_frosting,
"layer_add_forecast_date", "forecast_date"))) {
forecast_date <- extract_argument(the_frosting,
"layer_add_forecast_date", "forecast_date")

ahead <- extract_argument(the_recipe, "step_epi_ahead", "ahead")

target_date = forecast_date + ahead

} else if (is.null(object$target_date) ||
detect_layer(the_frosting, "layer_add_forecast_date")) {
max_time_value <- max(workflows::extract_preprocessor(workflow)$mtv,
workflow$fit$meta$mtv,
max(new_data$time_value))

if (is.null(object$target_date)) {
max_time_value <- max(workflows::extract_preprocessor(workflow)$mtv, workflow$fit$meta$mtv, max(new_data$time_value))
ahead <- extract_argument(the_recipe, "step_epi_ahead", "ahead")

if (is.null(ahead)){
stop("`ahead` must be specified in preprocessing.")
}
target_date = max_time_value + ahead
} else{
target_date = as.Date(object$target_date)
Expand Down
30 changes: 20 additions & 10 deletions man/layer_add_target_date.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

67 changes: 61 additions & 6 deletions tests/testthat/test-layer_add_target_date.R
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,12 @@ wf <- epi_workflow(r, parsnip::linear_reg()) %>% fit(jhu)
latest <- jhu %>%
dplyr::filter(time_value >= max(time_value) - 14)

test_that("Use ahead from preprocessing", {
test_that("Use ahead + max time value from pre, fit, post", {

f <- frosting() %>% layer_predict() %>%
layer_add_target_date() %>% layer_naomit(.pred)
f <- frosting() %>%
layer_predict() %>%
layer_add_target_date() %>%
layer_naomit(.pred)
wf1 <- wf %>% add_frosting(f)

expect_silent(p <- predict(wf1, latest))
Expand All @@ -21,12 +23,48 @@ test_that("Use ahead from preprocessing", {
expect_equal(nrow(p), 3L)
expect_equal(p$target_date, rep(as.Date("2022-01-07"), times = 3))
expect_named(p, c("geo_value", "time_value", ".pred", "target_date"))

# Should be same dates as above
f2 <- frosting() %>%
layer_predict() %>%
layer_add_forecast_date() %>%
layer_add_target_date() %>%
layer_naomit(.pred)
wf2 <- wf %>% add_frosting(f2)

expect_warning(p2 <- predict(wf2, latest))
expect_equal(ncol(p2), 5L)
expect_s3_class(p2, "epi_df")
expect_equal(nrow(p2), 3L)
expect_equal(p2$target_date, rep(as.Date("2022-01-07"), times = 3))
expect_named(p2, c("geo_value", "time_value", ".pred", "forecast_date", "target_date"))

})

test_that("Override default behaviour and specify own target date", {
test_that("Use ahead + specified forecast date", {

f <- frosting() %>% layer_predict() %>%
layer_add_target_date(target_date = "2022-01-08") %>% layer_naomit(.pred)
f <- frosting() %>%
layer_predict() %>%
layer_add_forecast_date(forecast_date = "2022-05-31") %>%
layer_add_target_date() %>%
layer_naomit(.pred)
wf1 <- wf %>% add_frosting(f)

expect_silent(p <- predict(wf1, latest))
expect_equal(ncol(p), 5L)
expect_s3_class(p, "epi_df")
expect_equal(nrow(p), 3L)
expect_equal(p$target_date, rep(as.Date("2022-06-07"), times = 3))
expect_named(p, c("geo_value", "time_value", ".pred", "forecast_date", "target_date"))

})

test_that("Specify own target date", {

f <- frosting() %>%
layer_predict() %>%
layer_add_target_date(target_date = "2022-01-08") %>%
layer_naomit(.pred)
wf1 <- wf %>% add_frosting(f)

expect_silent(p2 <- predict(wf1, latest))
Expand All @@ -36,3 +74,20 @@ test_that("Override default behaviour and specify own target date", {
expect_equal(p2$target_date, rep(as.Date("2022-01-08"), times = 3))
expect_named(p2, c("geo_value", "time_value", ".pred", "target_date"))
})

test_that("Specify own target date, but have a forecast date layer", {
  # Build the postprocessor in stages: prediction first, then a forecast date
  # layer with no date supplied, then a target date layer that is handed an
  # explicit date, and finally the NA filter on `.pred`.
  predicted <- frosting() %>% layer_predict()
  dated <- predicted %>%
    layer_add_forecast_date() %>%
    layer_add_target_date(target_date = "2022-01-08")
  post <- dated %>% layer_naomit(.pred)

  wf_post <- wf %>% add_frosting(post)

  # Prediction is expected to warn for this layer combination.
  expect_warning(preds <- predict(wf_post, latest))

  # Shape and class of the result: an `epi_df` with 3 rows and 5 columns.
  expect_s3_class(preds, "epi_df")
  expect_equal(nrow(preds), 3L)
  expect_equal(ncol(preds), 5L)
  expect_named(
    preds,
    c("geo_value", "time_value", ".pred", "forecast_date", "target_date")
  )

  # With the forecast date layer present, the expected target date is
  # 2022-01-07 rather than the supplied "2022-01-08".
  expect_equal(preds$target_date, rep(as.Date("2022-01-07"), times = 3))
})

0 comments on commit a42de5b

Please sign in to comment.