Skip to content

Commit

Permalink
Merge pull request #27 from josephsdavid/additional-wrappers
Browse files Browse the repository at this point in the history
added combined simple forecasts and ensembling of models
  • Loading branch information
ngupta23 authored Apr 8, 2020
2 parents 0df0720 + 43bad43 commit fa68a25
Show file tree
Hide file tree
Showing 23 changed files with 2,282 additions and 107 deletions.
3 changes: 2 additions & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: tswgewrapped
Title: Helpful wrappers for 'tswge', 'vars' and 'nnfor' time series packages
Version: 1.8.10.5
Version: 1.8.10.6
Authors@R: c(
person("David", "Josephs", email = "[email protected]", role = c("aut", "cre")),
person("Nikhil", "Gupta", email = "[email protected]", role = c("aut")))
Expand Down Expand Up @@ -31,6 +31,7 @@ Imports:
Rfast,
tibble,
tictoc,
tidyr (>= 1.0.0),
tswge,
vars
RoxygenNote: 7.1.0
Expand Down
214 changes: 156 additions & 58 deletions R/ModelCombine.R
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ ModelCombine = R6::R6Class(
#' @return A new `ModelCombine` object.
initialize = function(data = NA, var_interest = NA, uni_models = NA, var_models = NA, mlp_models = NA, verbose = 0)
{
self$set_verbose(verbose = verbose)
private$set_data(data = data)
private$set_var_interest(var_interest = var_interest)
private$set_uni_compare_objects(models = uni_models)
Expand Down Expand Up @@ -51,7 +52,61 @@ ModelCombine = R6::R6Class(
set_verbose = function(verbose = 0){
  # Store the verbosity level in the private 'verbose' field
  private$verbose = verbose
},

#### General Public Methods ----


#' @description Plots the simple forecasts for all the models on a single plot
#' @param lastn If TRUE, this will plot the forecasts for the last n.ahead values of the realization (Default: FALSE).
#'   Must be FALSE when a ModelCompareNNforCaret object has been passed to this object.
#' @param newxreg The future exogenous variable values to be used for prediction.
#'   Applicable to models that require the values of the new exogenous variables to be provided for future forecasts, e.g. nnfor::mlp()
#' @param limits If TRUE, this will also plot the lower and upper limits of the forecasts (Default: FALSE)
#' @param zoom A number indicating how much to zoom into the plot.
#'   For example zoom = 50 will only plot the last 50 points of the realization.
#'   Useful for cases where the realization is long and n.ahead is small.
#' @return The combined plot data collected from all the compare objects
plot_simple_forecasts = function(lastn = FALSE, newxreg = NA, limits = FALSE, zoom = NA){

  forecasts = data.frame()

  mlp_compare_objects = private$get_mlp_compare_objects()
  # Scalar condition: use the short-circuiting '&&' rather than the elementwise '&'
  # NOTE(review): 'self$classname' is not defined in the visible part of this class -- confirm it exists
  if (length(mlp_compare_objects) >= 1 && lastn == TRUE){
    stop(paste0("Your '", self$classname, "' object has a ModelCompareNNforCaret object which does not support plotting simple forecasts with lastn = TRUE. Please make lastn = FALSE and rerun with xreg passed."))
  }

  # Caret based models: keep only the rows retained by filter_best_caret_model
  for (i in seq_along(mlp_compare_objects)){
    subset_results = mlp_compare_objects[[i]]$plot_simple_forecasts(lastn = lastn, newxreg = newxreg, limits = limits, zoom = zoom, plot = FALSE)

    filtered = subset_results$plot_data %>%
      private$filter_best_caret_model(caret_compare_object = mlp_compare_objects[[i]])

    forecasts = dplyr::bind_rows(forecasts, filtered)
  }

  # Univariate models
  uni_compare_objects = private$get_uni_compare_objects()
  for (i in seq_along(uni_compare_objects)){
    subset_results = uni_compare_objects[[i]]$plot_simple_forecasts(lastn = lastn, newxreg = newxreg, limits = limits, zoom = zoom, plot = FALSE)
    forecasts = dplyr::bind_rows(forecasts, subset_results$plot_data)
  }

  # VAR models
  var_compare_objects = private$get_var_compare_objects()
  for (i in seq_along(var_compare_objects)){
    subset_results = var_compare_objects[[i]]$plot_simple_forecasts(lastn = lastn, newxreg = newxreg, limits = limits, zoom = zoom, plot = FALSE)
    forecasts = dplyr::bind_rows(forecasts, subset_results$plot_data)
  }

  # The "Actual" series is drawn with a thicker line than the model forecasts
  p = ggplot2::ggplot() +
    ggplot2::geom_line(forecasts %>% dplyr::filter(Model == "Actual"), mapping = ggplot2::aes(x=Time, y=f, color = Model), size = 1) +
    ggplot2::geom_line(forecasts %>% dplyr::filter(Model != "Actual"), mapping = ggplot2::aes(x=Time, y=f, color = Model), size = 0.75) +
    ggplot2::ylab("Simple Forecasts")

  if (limits == TRUE){
    # Lower (ll) and upper (ul) limits are drawn as dashed lines
    p = p +
      ggplot2::geom_line(forecasts, mapping = ggplot2::aes(x=Time, y=ll, color = Model), linetype = "dashed", size = 0.5) +
      ggplot2::geom_line(forecasts, mapping = ggplot2::aes(x=Time, y=ul, color = Model), linetype = "dashed", size = 0.5)
  }

  print(p)

  return(forecasts)

},

#' @description Plots the forecasts per batch for all models
#' @param only_sliding If TRUE, this will only plot the batch forecasts
#' for the models that used window ASE calculations
Expand Down Expand Up @@ -231,22 +286,110 @@ ModelCombine = R6::R6Class(
subset_results = mlp_compare_objects[[i]]$get_tabular_metrics(only_sliding = only_sliding, ases = ases) %>%
private$filter_best_caret_model(caret_compare_object = mlp_compare_objects[[i]])

# best_model_id = mlp_compare_objects[[i]]$get_best_model_id()
#
# subset_results = subset_results %>%
# dplyr::filter(Model == best_model_id)

results = rbind(results, subset_results)
}

return(results)
},

#' @description Computes the simple forecasts using all the models
#' @param lastn If TRUE, this will get the forecasts for the last n.ahead values of the realization (Default: FALSE).
#'   If there is a ModelCompareNNforCaret object passed to this object, then lastn must be FALSE
#'   (an error is raised otherwise).
#' @param newxreg The future exogenous variable values to be used for prediction.
#'   Applicable to models that require the values of the new exogenous variables to be provided for future forecasts, e.g. nnfor::mlp()
#' @return The forecasted values
compute_simple_forecasts = function(lastn = FALSE, newxreg = NA){

  forecasts = data.frame()

  mlp_compare_objects = private$get_mlp_compare_objects()
  # Scalar condition: use the short-circuiting '&&' rather than the elementwise '&'
  # NOTE(review): 'self$classname' is not defined in the visible part of this class -- confirm it exists
  if (length(mlp_compare_objects) >= 1 && lastn == TRUE){
    stop(paste0("Your '", self$classname, "' object has a ModelCompareNNforCaret object which does not support plotting simple forecasts with lastn = TRUE. Please make lastn = FALSE and rerun with xreg passed."))
  }

  # Caret based models: keep only the rows retained by filter_best_caret_model.
  # plot = FALSE so that the underlying objects only return the data without plotting.
  for (i in seq_along(mlp_compare_objects)){
    subset_results = mlp_compare_objects[[i]]$plot_simple_forecasts(lastn = lastn, newxreg = newxreg, limits = FALSE, plot = FALSE)

    filtered = subset_results$forecasts %>%
      private$filter_best_caret_model(caret_compare_object = mlp_compare_objects[[i]])

    forecasts = dplyr::bind_rows(forecasts, filtered)
  }

  # Univariate models
  uni_compare_objects = private$get_uni_compare_objects()
  for (i in seq_along(uni_compare_objects)){
    subset_results = uni_compare_objects[[i]]$plot_simple_forecasts(lastn = lastn, newxreg = newxreg, limits = FALSE, plot = FALSE)
    forecasts = dplyr::bind_rows(forecasts, subset_results$forecasts)
  }

  # VAR models
  var_compare_objects = private$get_var_compare_objects()
  for (i in seq_along(var_compare_objects)){
    subset_results = var_compare_objects[[i]]$plot_simple_forecasts(lastn = lastn, newxreg = newxreg, limits = FALSE, plot = FALSE)
    forecasts = dplyr::bind_rows(forecasts, subset_results$forecasts)
  }

  return(forecasts)
},

#' @description Creates an ensemble model based on all the models provided.
#' A glm is fit with 'Realization' as the response and the batch forecasts
#' of each model as the predictors. The fitted model is stored in this object
#' and is used by `predict_ensemble(naive = FALSE)`.
create_ensemble = function(){
  # One column per model (plus 'Realization'), one row per time point
  data_for_model = self$get_tabular_metrics(only_sliding = TRUE, ases = FALSE) %>%
    dplyr::distinct() %>%  # Remove duplicate entries for Model = 'Realization'
    assertr::verify(assertr::has_all_names("Time", "Model", "f")) %>%
    tidyr::pivot_wider(id_cols = Time, names_from = Model, values_from = f) %>%
    stats::na.omit() %>%
    dplyr::select(-Time)

  # Diagnostic output only when verbosity is requested.
  # str() prints by itself; wrapping it in print() additionally emitted a stray 'NULL'.
  if (private$get_verbose() >= 1){
    str(data_for_model)
  }

  glm_ensemble = stats::glm(formula = Realization ~ ., data = data_for_model)

  if (private$get_verbose() >= 1){
    print(summary(glm_ensemble))
  }
  private$set_ensemble_model(model = glm_ensemble)
},

#' @description Makes a prediction based on the ensemble model
#' @param naive If TRUE, the ensemble will be a simple combination (mean or median, see 'comb') of the predictions of all the models.
#'   If FALSE, the ensemble will use the glm model created by `create_ensemble()` from the batch predictions of all the models.
#' @param comb If 'naive' = TRUE, how to combine the predictions. Allowed values are 'mean' or 'median' (Default: 'median')
#' @param newxreg The future exogenous variable values to be used for prediction.
#'   Applicable to models that require the values of the new exogenous variables to be provided for future forecasts, e.g. nnfor::mlp()
#' @return The predictions from each model along with the ensemble prediction
predict_ensemble = function(naive = FALSE, comb = 'median', newxreg = NA){

  if (naive == TRUE){
    # Scalar condition: use the short-circuiting '&&' rather than the elementwise '&'
    if (comb != 'median' && comb != 'mean'){
      warning(paste0("You are using a naive model, but the value of comb is set to '", comb, "' . The allowed values are 'median' or 'mean'. This will be set to the default value of 'median'."))
      comb = 'median'
    }
  } else {
    # Fail early with a clear message instead of an obscure predict() dispatch error
    # when the ensemble model still holds its initial NA value.
    if (identical(private$get_ensemble_model(), NA)){
      stop("The ensemble model has not been created yet. Please call 'create_ensemble()' first or use naive = TRUE.")
    }
  }

  # One column per model, one row per forecasted time point
  forecasts = self$compute_simple_forecasts(lastn = FALSE, newxreg = newxreg) %>%
    assertr::verify(assertr::has_all_names("Time", "Model", "f")) %>%
    tidyr::pivot_wider(id_cols = Time, names_from = Model, values_from = f) %>%
    stats::na.omit() %>%
    dplyr::select(-Time)

  if (naive == TRUE){
    if (comb == 'mean'){
      forecasts = forecasts %>%
        dplyr::mutate(ensemble = rowMeans(.))
    }
    if (comb == 'median'){
      forecasts = forecasts %>%
        dplyr::mutate(ensemble = Rfast::rowMedians(as.matrix(.)))
    }
  } else {
    forecasts = forecasts %>%
      dplyr::mutate(ensemble = stats::predict(private$get_ensemble_model(), newdata = forecasts)) %>%
      dplyr::mutate_if(is.numeric, as.double)  # Converts Named numeric (output of predict) to simple numeric
  }

  return(forecasts)

}

),


Expand All @@ -258,6 +401,7 @@ ModelCombine = R6::R6Class(
var_models = NA,
mlp_models = NA,
verbose = NA,
ensemble_model = NA,

set_data = function(data){
if (all(is.na(data))){ stop("You have not provided the time series data. Please provide to continue.") }
Expand Down Expand Up @@ -285,58 +429,12 @@ ModelCombine = R6::R6Class(
},
get_mlp_compare_objects = function(){
  # Accessor for the private 'mlp_models' field
  private$mlp_models
},

get_len_x = function(){return(nrow(self$get_data()))},

compute_simple_forecasts = function(lastn){
## TODO: Needed for NNFOR
## But add an argument xreg
## Used by plot_simple_forecasts in the base class

message("This function is not supported for nnfor::mlp at this time.")

results = dplyr::tribble(~Model, ~Time, ~f, ~ll, ~ul)
#
# if (lastn == FALSE){
# data_start = 1
# data_end = private$get_len_x()
# train_data = self$get_data()[data_start:data_end, ]
#
# }
# else{
# data_start = 1
# data_end = private$get_len_x() - self$get_n.ahead()
# train_data = self$get_data()[data_start:data_end, ]
# }
#
# from = data_end + 1
# to = data_end + self$get_n.ahead()
#
# # Define Train Data
#
# for (name in names(private$get_models())){
#
# var_interest = self$get_var_interest()
# k = private$get_models()[[name]][['k_final']]
# trend_type = private$get_models()[[name]][['trend_type']]
#
# # Fit model for the batch
# varfit = vars::VAR(train_data, p=k, type=trend_type)
#
# # Forecast for the batch
# forecasts = stats::predict(varfit, n.ahead=self$get_n.ahead())
# forecasts = forecasts$fcst[[var_interest]] ## Get the forecasts only for the dependent variable
#
# results = results %>%
# dplyr::add_row(Model = name,
# Time = (from:to),
# f = forecasts[, 'fcst'],
# ll = forecasts[, 'lower'],
# ul = forecasts[, 'upper'])
#
# }
#
return(results)
set_ensemble_model = function(model){
  # Store the fitted ensemble model in the private 'ensemble_model' field
  private$ensemble_model = model
},
get_ensemble_model = function(){
  # Accessor for the private 'ensemble_model' field
  private$ensemble_model
},

get_len_x = function(){
  # Number of rows in the data returned by self$get_data()
  n_rows = nrow(self$get_data())
  n_rows
},

filter_best_caret_model = function(data, caret_compare_object){
# Given a caret_compare_object and a dataframe 'data' that has a 'Model' column
Expand Down
30 changes: 17 additions & 13 deletions R/ModelCompareBase.R
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,8 @@ ModelCompareBase = R6::R6Class(
#' @param zoom A number indicating how much to zoom into the plot.
#' For example zoom = 50 will only plot the last 50 points of the realization
#' Useful for cases where realizations that are long and n.ahead is small.
plot_simple_forecasts = function(lastn = FALSE, newxreg = NA, limits = FALSE, zoom = NA){
#' @param plot If FALSE the plots are not plotted; useful when you want to just return the data (Default = TRUE)
plot_simple_forecasts = function(lastn = FALSE, newxreg = NA, limits = FALSE, zoom = NA, plot = TRUE){

forecasts = private$compute_simple_forecasts_with_validation(lastn = lastn, newxreg = newxreg)

Expand All @@ -236,6 +237,7 @@ ModelCompareBase = R6::R6Class(
ll = self$get_data_var_interest(),
ul = self$get_data_var_interest())


if (!is.na(zoom)){
zoom = private$validate_zoom(zoom)

Expand All @@ -246,20 +248,22 @@ ModelCompareBase = R6::R6Class(
dplyr::filter(Time >= start)
}

p = ggplot2::ggplot() +
ggplot2::geom_line(results %>% dplyr::filter(Model == "Actual"), mapping = ggplot2::aes(x=Time, y=f, color = Model), size = 1) +
ggplot2::geom_line(results %>% dplyr::filter(Model != "Actual"), mapping = ggplot2::aes(x=Time, y=f, color = Model), size = 0.75) +
ggplot2::ylab("Simple Forecasts")

if (limits == TRUE){
p = p +
ggplot2::geom_line(results, mapping = ggplot2::aes(x=Time, y=ll, color = Model), linetype = "dashed", size = 0.5) +
ggplot2::geom_line(results, mapping = ggplot2::aes(x=Time, y=ul, color = Model), linetype = "dashed", size = 0.5)
if (plot == TRUE){
p = ggplot2::ggplot() +
ggplot2::geom_line(results %>% dplyr::filter(Model == "Actual"), mapping = ggplot2::aes(x=Time, y=f, color = Model), size = 1) +
ggplot2::geom_line(results %>% dplyr::filter(Model != "Actual"), mapping = ggplot2::aes(x=Time, y=f, color = Model), size = 0.75) +
ggplot2::ylab("Simple Forecasts")

if (limits == TRUE){
p = p +
ggplot2::geom_line(results, mapping = ggplot2::aes(x=Time, y=ll, color = Model), linetype = "dashed", size = 0.5) +
ggplot2::geom_line(results, mapping = ggplot2::aes(x=Time, y=ul, color = Model), linetype = "dashed", size = 0.5)
}

print(p)
}

print(p)

return(forecasts)
return(list(forecasts = forecasts, plot_data = results))

},

Expand Down
2 changes: 1 addition & 1 deletion R/ModelCompareNNforCaret.R
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ ModelCompareNNforCaret = R6::R6Class(

compute_simple_forecasts = function(lastn, newxreg){
if (lastn == TRUE){
message("This class does not support lastn = TRUE since the model has already been built using the entire data. Hence, lastn will be set to FALSE.")
message(paste0("The '", self$classname, "' class does not support lastn = TRUE since the model has already been built using the entire data. Hence, lastn will be set to FALSE."))
lastn = FALSE
}

Expand Down
Binary file added build/tswgewrapped_1.8.10.6.tar.gz
Binary file not shown.
Binary file added inst/extdata/caret_model_train_bs120.rds
Binary file not shown.
File renamed without changes.
3 changes: 3 additions & 0 deletions inst/extdata/ensemble_glm_train_bs120.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
"reps19_hd2_sdetFALSE","Univar A","Univar B","Univar C","AIC Both - R","AIC Trend - R","ensemble"
8.25875556986836,8.2597570927,8.2642322988,8.2642322988,8.26087898347003,8.26020162823867,8.2643919071676
8.26214895340459,8.26962681553,8.27937511312,8.27937511312,8.26549897076232,8.26297965565337,8.27499807022726
3 changes: 3 additions & 0 deletions inst/extdata/ensemble_naive_mean_train_bs120.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
"reps19_hd2_sdetFALSE","Univar A","Univar B","Univar C","AIC Both - R","AIC Trend - R","ensemble"
8.25875556986836,8.2597570927,8.2642322988,8.2642322988,8.26087898347003,8.26020162823867,8.26134297864617
8.26214895340459,8.26962681553,8.27937511312,8.27937511312,8.26549897076232,8.26297965565337,8.26983410359838
3 changes: 3 additions & 0 deletions inst/extdata/ensemble_naive_median_train_bs120.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
"reps19_hd2_sdetFALSE","Univar A","Univar B","Univar C","AIC Both - R","AIC Trend - R","ensemble"
8.25875556986836,8.2597570927,8.2642322988,8.2642322988,8.26087898347003,8.26020162823867,8.26054030585435
8.26214895340459,8.26962681553,8.27937511312,8.27937511312,8.26549897076232,8.26297965565337,8.26756289314616
42 changes: 42 additions & 0 deletions inst/extdata/mdl_combine_ases1_train_bs120.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
"Model","ASE","Time_Test_Start","Time_Test_End","Batch"
"Univar A",3.12308225764145e-05,119,120,1
"Univar A",8.31434200888188e-05,121,122,2
"Univar A",8.99107613814644e-05,123,124,3
"Univar A",9.23963889730699e-05,125,126,4
"Univar A",3.58922274584229e-05,127,128,5
"Univar A",3.96032672843582e-05,129,130,6
"Univar A",5.75283877703727e-05,131,132,7
"Univar A",0.000179876397666713,133,134,8
"Univar B",3.11894455528081e-05,119,120,1
"Univar B",0.000195071292086915,121,122,2
"Univar B",2.0555112379594e-05,123,124,3
"Univar B",3.4537354499041e-06,125,126,4
"Univar B",0.000238009765613329,127,128,5
"Univar B",5.00922427881027e-05,129,130,6
"Univar B",3.61710997771645e-06,131,132,7
"Univar B",0.000133945482877433,133,134,8
"Univar C",0.000133945482877433,133,134,1
"AIC Both - R",9.53415765907678e-05,119,120,1
"AIC Both - R",5.63432845915735e-05,121,122,2
"AIC Both - R",0.000162816178046665,123,124,3
"AIC Both - R",5.09735054355053e-06,125,126,4
"AIC Both - R",0.000149402359459911,127,128,5
"AIC Both - R",0.000352126894297292,129,130,6
"AIC Both - R",0.00111179164980607,131,132,7
"AIC Both - R",9.41956795816891e-05,133,134,8
"AIC Trend - R",4.07323799036183e-06,119,120,1
"AIC Trend - R",0.000225427731940963,121,122,2
"AIC Trend - R",1.84818669129706e-05,123,124,3
"AIC Trend - R",9.46361043687706e-06,125,126,4
"AIC Trend - R",0.00019476432486588,127,128,5
"AIC Trend - R",0.00032344234699412,129,130,6
"AIC Trend - R",0.000689865028845393,131,132,7
"AIC Trend - R",2.47277635958661e-06,133,134,8
"reps19_hd2_sdetFALSE",4.01677184126864e-05,119,120,1
"reps19_hd2_sdetFALSE",0.000272858996316274,121,122,2
"reps19_hd2_sdetFALSE",0.000262383678444808,123,124,3
"reps19_hd2_sdetFALSE",1.96260233679372e-05,125,126,4
"reps19_hd2_sdetFALSE",5.52812013604825e-05,127,128,5
"reps19_hd2_sdetFALSE",5.52334056835546e-05,129,130,6
"reps19_hd2_sdetFALSE",0.000219281569444876,131,132,7
"reps19_hd2_sdetFALSE",0.000640217368054598,133,134,8
File renamed without changes.
Loading

0 comments on commit fa68a25

Please sign in to comment.