Skip to content

Commit

Permalink
Update doc + code to pick max pre/fit/post
Browse files Browse the repository at this point in the history
  • Loading branch information
rachlobay committed Jul 25, 2023
1 parent 06c71d9 commit 106b2ad
Show file tree
Hide file tree
Showing 7 changed files with 45 additions and 26 deletions.
11 changes: 5 additions & 6 deletions R/epi_workflow.R
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,11 @@ is_epi_workflow <- function(x) {
#'
#' @param object an `epi_workflow` object
#'
#' @param x an `epi_df` of predictors and outcomes to use when
#' @param data an `epi_df` of predictors and outcomes to use when
#' fitting the `epi_workflow`
#'
#' @param control A [workflows::control_workflow()] object
#'
#' @return The `epi_workflow` object, updated with a fit parsnip
#' model in the `object$fit$fit` slot.
#'
Expand All @@ -92,10 +94,9 @@ is_epi_workflow <- function(x) {
#' wf
#'
#' @export
fit.epi_workflow <- function(object, x, ...){
fit.epi_workflow <- function(object, data, ..., control = workflows::control_workflow()){

object$fit$meta <- list(mtv = max(x$time_value))
#object$fit$as_of <- attributes(x)$metadata$as_of
object$fit$meta <- list(mtv = max(data$time_value), as_of = attributes(data)$metadata$as_of)

NextMethod()
}
Expand Down Expand Up @@ -153,15 +154,13 @@ predict.epi_workflow <- function(object, new_data, ...) {
c("Can't predict on an untrained epi_workflow.",
i = "Do you need to call `fit()`?"))
}

components <- list()
components$mold <- workflows::extract_mold(object)
components$forged <- hardhat::forge(new_data,
blueprint = components$mold$blueprint)
components$keys <- grab_forged_keys(components$forged,
components$mold, new_data)
components <- apply_frosting(object, components, new_data, ...)

components$predictions
}

Expand Down
9 changes: 9 additions & 0 deletions R/frosting.R
Original file line number Diff line number Diff line change
Expand Up @@ -267,10 +267,19 @@ apply_frosting.epi_workflow <-
la <- layers[[l]]
components <- slather(la, components, workflow, new_data)
}
#%% mtv <- max(new_data$time_value)
#%% update_workflow_post(workflow, mtv)

return(components)
}

#%% #' @export
# update_workflow_post <- function(x, mtv) {
# substitute(x) <- "changed"
# #assign(deparse(substitute(x)), "changed", env=.GlobalEnv)
# #workflow$post$meta <- list(mtv = max(new_data$time_value))
# }

#%% change_workflow = function(x){
# assign(deparse(substitute(x)), "changed", env=.GlobalEnv)
#}
Expand Down
22 changes: 14 additions & 8 deletions R/layer_add_forecast_date.R
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,18 @@
#' @param forecast_date The forecast date to add as a column to the `epi_df`.
#' For most cases, this should be specified in the form "yyyy-mm-dd". Note that
#' when the forecast date is left unspecified, it is set to the maximum time
#' value in the test data after any processing (ex. leads and lags) has been
#' applied.
#' value from the data used in pre-processing, fitting the model, and
#' postprocessing.
#' @param id a random id string
#'
#' @return an updated `frosting` postprocessor
#'
#' @details To use this function, either specify a forecast date or leave the
#' forecast date unspecifed here. In the latter case, the forecast date will
#' be set as the maximum time value in the processed test data. In any case,
#' when the forecast date is less than the most recent update date of the data
#' (ie. the `as_of` value), an appropriate warning will be thrown.
#' be set as the maximum time value from the data used in pre-processing,
#' fitting the model, and postprocessing. In any case, when the forecast date is
#' less than the maximum `as_of` value (from the data used pre-processing,
#' model fitting, and postprocessing), an appropriate warning will be thrown.
#'
#' @export
#' @examples
Expand Down Expand Up @@ -82,14 +83,19 @@ layer_add_forecast_date_new <- function(forecast_date, id) {

#' @export
slather.layer_add_forecast_date <- function(object, components, workflow, new_data, ...) {
wf <<- workflow
#%% wf <<- workflow
#%% comp <<- components
if (is.null(object$forecast_date)) {
max_time_value <- max(workflows::extract_preprocessor(wf)$mtv, wf$fit$meta$mtv, max(new_data$time_value))#wf$post$meta$mtv) # workflow$fit$max_train_time #max(new_data$time_value)
max_time_value <- max(workflows::extract_preprocessor(workflow)$mtv, workflow$fit$meta$mtv, max(new_data$time_value)) #%% wf$post$meta$mtv) # workflow$fit$max_train_time #max(new_data$time_value)
object$forecast_date <- max_time_value
}
as_of_pre <- attributes(workflows::extract_preprocessor(workflow)$template)$metadata$as_of
as_of_fit <- workflow$fit$meta$as_of
as_of_post <- attributes(new_data)$metadata$as_of

as_of_date <- as.Date(attributes(components$keys)$metadata$as_of)
as_of_date <- as.Date(max(as_of_pre, as_of_fit, as_of_post)) #%% as.Date(attributes(components$keys)$metadata$as_of)

# It would be nice to say that forecast_date is >= to the max of all of them.
if (object$forecast_date < as_of_date) {
cli_warn(
c("The forecast_date is less than the most ",
Expand Down
7 changes: 4 additions & 3 deletions R/layer_add_target_date.R
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@
#'
#' @param frosting a `frosting` postprocessor
#' @param target_date The target date to add as a column to the `epi_df`.
#' By default, this is the maximum `time_value` from the processed test
#' data plus `ahead`, where `ahead` has been specified in preprocessing
#' By default, this is the maximum `time_value` (from the data used in
#' pre-processing, fitting the model, and postprocessing) plus `ahead`,
#' where `ahead` has been specified in preprocessing
#' (most likely in `step_epi_ahead`). The user may override this with a
#' date of their own (that will usually be in the form "yyyy-mm-dd").
#' @param id a random id string
Expand Down Expand Up @@ -68,7 +69,7 @@ slather.layer_add_target_date <- function(object, components, workflow, new_data
the_recipe <- workflows::extract_recipe(workflow)

if (is.null(object$target_date)) {
max_time_value <- max(new_data$time_value)
max_time_value <- max(workflows::extract_preprocessor(workflow)$mtv, workflow$fit$meta$mtv, max(new_data$time_value))
ahead <- extract_argument(the_recipe, "step_epi_ahead", "ahead")

if (is.null(ahead)){
Expand Down
6 changes: 4 additions & 2 deletions man/fit-epi_workflow.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

11 changes: 6 additions & 5 deletions man/layer_add_forecast_date.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 3 additions & 2 deletions man/layer_add_target_date.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 106b2ad

Please sign in to comment.