Skip to content

Commit

Permalink
Doc. for fit.epi_workflow and trying different things for mtv
Browse files Browse the repository at this point in the history
  • Loading branch information
rachlobay committed Jul 23, 2023
1 parent 47d7518 commit 06c71d9
Show file tree
Hide file tree
Showing 6 changed files with 108 additions and 2 deletions.
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ S3method(extract_layers,workflow)
S3method(extrapolate_quantiles,dist_default)
S3method(extrapolate_quantiles,dist_quantiles)
S3method(extrapolate_quantiles,distribution)
S3method(fit,epi_workflow)
S3method(format,dist_quantiles)
S3method(is.na,dist_quantiles)
S3method(is.na,distribution)
Expand Down
2 changes: 2 additions & 0 deletions R/epi_recipe.R
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ epi_recipe.epi_df <-
term_info = var_info,
steps = NULL,
template = x[1,],
mtv = max(x$time_value), #%%
levels = NULL,
retained = NA
)
Expand Down Expand Up @@ -374,6 +375,7 @@ prep.epi_recipe <- function(
} else {
x$template <- training[0, ]
}
x$mtv <- max(training$time_value) #%%
x$tr_info <- tr_data
x$levels <- lvls
x$orig_lvls <- orig_lvls
Expand Down
43 changes: 43 additions & 0 deletions R/epi_workflow.R
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,47 @@ is_epi_workflow <- function(x) {
inherits(x, "epi_workflow")
}

#' Fit an `epi_workflow` object
#'
#' @description
#' This is the `fit()` method for an `epi_workflow` object that
#' estimates parameters for a given model from a set of data.
#' Fitting an `epi_workflow` involves two main steps, which are
#' preprocessing the data and fitting the underlying parsnip model.
#'
#' @inheritParams workflows::fit.workflow
#'
#' @param object an `epi_workflow` object
#'
#' @param x an `epi_df` of predictors and outcomes to use when
#' fitting the `epi_workflow`
#'
#' @return The `epi_workflow` object, updated with a fit parsnip
#' model in the `object$fit$fit` slot.
#'
#' @seealso workflows::fit-workflow
#'
#' @name fit-epi_workflow
#' @export
#' @examples
#' jhu <- case_death_rate_subset %>%
#' filter(time_value > "2021-11-01", geo_value %in% c("ak", "ca", "ny"))
#'
#' r <- epi_recipe(jhu) %>%
#' step_epi_lag(death_rate, lag = c(0, 7, 14)) %>%
#' step_epi_ahead(death_rate, ahead = 7)
#'
#' wf <- epi_workflow(r, parsnip::linear_reg()) %>% fit(jhu)
#' wf
#'
#' @export
fit.epi_workflow <- function(object, x, ...){

object$fit$meta <- list(mtv = max(x$time_value))
#object$fit$as_of <- attributes(x)$metadata$as_of

NextMethod()
}

#' Predict from an epi_workflow
#'
Expand Down Expand Up @@ -112,13 +153,15 @@ predict.epi_workflow <- function(object, new_data, ...) {
c("Can't predict on an untrained epi_workflow.",
i = "Do you need to call `fit()`?"))
}

components <- list()
components$mold <- workflows::extract_mold(object)
components$forged <- hardhat::forge(new_data,
blueprint = components$mold$blueprint)
components$keys <- grab_forged_keys(components$forged,
components$mold, new_data)
components <- apply_frosting(object, components, new_data, ...)

components$predictions
}

Expand Down
18 changes: 18 additions & 0 deletions R/frosting.R
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,9 @@ apply_frosting.default <- function(workflow, components, ...) {
apply_frosting.epi_workflow <-
function(workflow, components, new_data, ...) {

#%% wf1$post$meta <<- list(mtv = max(new_data$time_value)) #%% change wf1 and possibly <<-
# assign("workflow$post$meta", list(mtv = max(new_data$time_value)), envir = .GlobalEnv)

the_fit <- workflows::extract_fit_parsnip(workflow)

if (!has_postprocessor(workflow)) {
Expand Down Expand Up @@ -268,6 +271,21 @@ apply_frosting.epi_workflow <-
return(components)
}

#%% change_workflow = function(x){
# assign(deparse(substitute(x)), "changed", env=.GlobalEnv)
#}

#%% add_meta_post <- function(workflow, new_data){
#
# workflow$post$meta <- list(mtv = max(new_data$time_value))
#
# workflow
#}

#%% changeMe = function(x){
# assign(deparse(substitute(x)), "changed", env=.GlobalEnv)
#}

#' @export
print.frosting <- function(x, form_width = 30, ...) {
cli::cli_div(
Expand Down
4 changes: 2 additions & 2 deletions R/layer_add_forecast_date.R
Original file line number Diff line number Diff line change
Expand Up @@ -82,9 +82,9 @@ layer_add_forecast_date_new <- function(forecast_date, id) {

#' @export
slather.layer_add_forecast_date <- function(object, components, workflow, new_data, ...) {

wf <<- workflow
if (is.null(object$forecast_date)) {
max_time_value <- max(new_data$time_value)
max_time_value <- max(workflows::extract_preprocessor(wf)$mtv, wf$fit$meta$mtv, max(new_data$time_value))#wf$post$meta$mtv) # workflow$fit$max_train_time #max(new_data$time_value)
object$forecast_date <- max_time_value
}

Expand Down
42 changes: 42 additions & 0 deletions man/fit-epi_workflow.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 06c71d9

Please sign in to comment.