Skip to content

Commit

Permalink
Merge pull request #21 from josephsdavid/additional-wrappers
Browse files Browse the repository at this point in the history
Additional wrappers
  • Loading branch information
ngupta23 authored Apr 3, 2020
2 parents 127d598 + 8bdd671 commit ab58730
Show file tree
Hide file tree
Showing 8 changed files with 136 additions and 70 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: tswgewrapped
Title: Helpful wrappers for 'tswge', 'vars' and 'nnfor' time series packages
Version: 1.8.10.1
Version: 1.8.10.2
Authors@R: c(
person("David", "Josephs", email = "[email protected]", role = c("aut", "cre")),
person("Nikhil", "Gupta", email = "[email protected]", role = c("aut")))
Expand Down
102 changes: 60 additions & 42 deletions R/ModelBuildNNforCaret.R
Original file line number Diff line number Diff line change
Expand Up @@ -64,50 +64,47 @@ ModelBuildNNforCaret = R6::R6Class(



#' @description Returns the dependent variable data only
#' @return The dependent variable data only
get_data_var_interest = function(){return(self$get_data()[, private$get_var_interest()])},

# #' @description Returns the dependent variable data only
# #' @return The dependent variable data only
# get_data_var_interest = function(){return(self$get_data()[, private$get_var_interest()])},





#### General Public Methods ----

#' @description Returns the VAR model Build Summary
#' @returns A dataframe containing the following columns
#' 'Model': Name of the model
#' 'Selection': The selection criteria used for K value (AIC or BIC)
#' 'Trend': The trend argument used in the VARselect and VAR functions
#' 'SlidingASE': Whether Sliding ASE will be used for this model
#' 'Init_K': The K value recommended by the VARselect function
#' 'Final_K': The adjusted K value to take into account the smaller batch size (only when using sliding_ase)
summarize_build = function(){
results = dplyr::tribble(~Model, ~select, ~trend_type, ~season, ~p, ~SigVar, ~OriginalVar, ~Lag, ~MaxLag)

for (name in names(private$get_models())){
results = results %>%
dplyr::add_row(Model = name,
select = private$models[[name]][['select']],
trend_type = private$models[[name]][['trend_type']],
season = ifelse(is.null(private$models[[name]][['season']]), 0, private$models[[name]][['season']]),
p = private$models[[name]][['p']],
SigVar = private$models[[name]][['sigvars']][['sig_var']],
OriginalVar = private$models[[name]][['sigvars']][['original_var']],
Lag = private$models[[name]][['sigvars']][['lag']],
MaxLag = private$models[[name]][['sigvars']][['max_lag']]
)
}
#' @description Reports the tuning results for all the hyperparameter combinations
#' @returns The 'results' element of the caret model object
#'   (a dataframe describing the different models that were evaluated)
summarize_hyperparam_results = function(){
  # subset = 'a' asks get_final_models() for the caret model object itself
  tuned <- self$get_final_models(subset = 'a')
  tuned$results
},

#' @description Reports the hyperparameter combination that was selected as the best one
#' @returns The 'bestTune' element of the caret model object
#'   (a dataframe holding the hyperparameters of the best model)
summarize_best_hyperparams = function(){
  # subset = 'a' asks get_final_models() for the caret model object itself
  tuned <- self$get_final_models(subset = 'a')
  tuned$bestTune
},

#' @description Plots the ASE metric variation along the hyperparameter space
#' @param level_plot A boolean indicating whether a level plot should be shown
#'   in addition to the default plot. Useful for 'grid' searches, less so for
#'   'random' searches (Default = TRUE).
plot_hyperparam_results = function(level_plot = TRUE){
  caret_model = self$get_final_models(subset = 'a')

  # ggplot objects are only rendered when printed, so print explicitly
  print(ggplot2::ggplot(caret_model))

  # isTRUE() keeps NA / NULL / zero-length input from raising an error in `if`
  if (isTRUE(as.logical(level_plot))){
    # Opt 3 (Useful for grid searches not for random)
    # lattice::trellis.par.set(caretTheme())
    # The level plot must be printed explicitly: relying on top-level
    # auto-printing silently drops it whenever this method is called from
    # inside another function (e.g. summarize_build).
    print(plot(caret_model, metric = "ASE", plotType = "level",
               scales = list(x = list(rot = 90))))
  }
},

#' @description Returns the final model(s)
#' @param subset The subset of models to get.
#' 'a': All models (Default)
#' 'r': Only the recommended models
#' @return A named list of models
#' @return If subset = 'a', returns the caret model object
#' If subset = 'r', returns just the nnfor model
get_final_models = function(subset = 'a'){
if (subset != 'a' & subset != 'r'){
warning("The subset value mentioned is not correct. Allowed values are 'a', or 'r. The default value 'a' will be used")
Expand All @@ -121,6 +118,27 @@ ModelBuildNNforCaret = R6::R6Class(
return(private$get_models()$finalModel)
}

},

#' @description Prints a report of the whole build process: the hyperparameter
#'   tuning results, the hyperparameter plots, the best hyperparameters and
#'   the final model
#' @param level_plot A boolean indicating whether a level plot should be shown. useful for 'grid' search (Default = TRUE).
summarize_build = function(level_plot = TRUE){
  # Section 1: results for every hyperparameter combination
  cat("\n\n------------------------------",
      "\nHyperparameter Tuning Results:",
      "\n------------------------------\n\n",
      sep = "")
  tuning_results <- self$summarize_hyperparam_results()
  print(tuning_results)

  # Plots are produced between the first and second text sections
  self$plot_hyperparam_results(level_plot = level_plot)

  # Section 2: the winning hyperparameter combination
  cat("\n\n---------------------",
      "\nBest Hyperparameters:",
      "\n---------------------\n\n",
      sep = "")
  best_params <- self$summarize_best_hyperparams()
  print(best_params)

  # Section 3: the final model (subset = 'r')
  cat("\n\n--------------",
      "\nFinal Model:",
      "\n--------------\n\n",
      sep = "")
  final_model <- self$get_final_models(subset = 'r')
  print(final_model)
}

),
Expand Down Expand Up @@ -270,22 +288,22 @@ ModelBuildNNforCaret = R6::R6Class(
verbose = as.logical(private$get_verbose()),
parallel = private$get_parallel())

print(fitControl)
# print(fitControl)

# http://sshaikh.org/2015/05/06/parallelize-machine-learning-in-r-with-multi-core-cpus/
if (private$get_parallel() == TRUE){
num_cores = parallel::detectCores()
cl = parallel::makeCluster(ifelse(num_cores <= 2, 1, num_cores - 2)) # Leave 2 out
cl = parallel::makeCluster(ifelse(num_cores <= 1, 1, num_cores - 1)) # Leave 1 out
doParallel::registerDoParallel(cl)
}

form = as.formula(paste(private$get_var_interest(), ".", sep=" ~ "))

print(paste0("Formula: ", form))
print("Grid: ")
print(private$get_grid())
print(paste0("Tune Length: ", private$get_tune_length()))
print(paste0("Frequency: ", private$get_m()))
# print(paste0("Formula: ", form))
# print("Grid: ")
# print(private$get_grid())
# print(paste0("Tune Length: ", private$get_tune_length()))
# print(paste0("Frequency: ", private$get_m()))

tictoc::tic("- Total Time for training: ")

Expand Down Expand Up @@ -356,7 +374,7 @@ ModelBuildNNforCaret = R6::R6Class(

get_fit_control = function(initialWindow, h, search = "random", verbose = TRUE, parallel = TRUE){

print(paste0("get_fit_control >> verbose: ", verbose))
# print(paste0("get_fit_control >> verbose: ", verbose))

fitControl = caret::trainControl(method = "timeslice",
horizon = h,
Expand Down
2 changes: 1 addition & 1 deletion R/ModelCompareUnivariate.R
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ ModelCompareUnivariate = R6::R6Class(
s = private$get_models()[[name]][['s']],
n.ahead = self$get_n.ahead(),
batch_size = private$get_models()[[name]][['batch_size']],
step_n.ahead = step_n.ahead)
step_n.ahead = step_n.ahead, verbose = private$get_verbose())

return (res)
},
Expand Down
10 changes: 8 additions & 2 deletions R/sliding_ase.R
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#' @param batch_size Window Size used
#' @param n.ahead last n.ahead data points in each batch will be used for prediction and ASE calculations
#' @param step_n.ahead Whether to step each batch by n.ahead values (Default = FALSE)
#' @param verbose How much to print during the model building and other processes (Default = 0)
#' @param ... any additional arguments to be passed to the forecast functions (e.g. max.p for sigplusnoise model, lambda for ARUMA models)
#' @return Named list
#' 'ASEs' - ASE values
Expand All @@ -26,6 +27,7 @@ sliding_ase_univariate = function(x,
linear = NA, freq = NA, # Signal + Noise arguments
n.ahead = NA, batch_size = NA, # Forecasting specific arguments
step_n.ahead = TRUE,
verbose = 0,
...) # max.p (sigplusnoise), lambda (ARUMA)
{
# Sliding CV ... batches are mutually exclusive
Expand Down Expand Up @@ -73,7 +75,9 @@ sliding_ase_univariate = function(x,
num_batches = floor((n-batch_size)/n.ahead) + 1
}

cat(paste("\nNumber of batches expected: ", num_batches))
if (verbose >= 1){
cat(paste("\nNumber of batches expected: ", num_batches))
}

ASEs = numeric(num_batches)
time_test_start = numeric(num_batches)
Expand Down Expand Up @@ -193,7 +197,9 @@ sliding_ase_var = function(data, var_interest,
num_batches = floor((n-batch_size)/n.ahead) + 1
}

cat(paste("\nNumber of batches expected: ", num_batches))
if (verbose >= 1){
cat(paste("\nNumber of batches expected: ", num_batches))
}

ASEs = numeric(num_batches)
time_test_start = numeric(num_batches)
Expand Down
Binary file added build/tswgewrapped_1.8.10.2.tar.gz
Binary file not shown.
74 changes: 52 additions & 22 deletions man/ModelBuildNNforCaret.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions man/sliding_ase_univariate.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

13 changes: 11 additions & 2 deletions tests/testthat/test-BuildNNforCaret.R
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,17 @@
## TODO: Then compare above to manual test in nnfor
## TODO: Check reproducibility of 2 runs with caret

## TODO: For now, we are predicting with the actual future values of xreg.
## In the future, change this to use forecasted values of xreg
## This can probably be achieved by passing xreg_name_predicted in the data and adjusting the source_caret_nnfor file accordingly

# test_that("Random Parallel", {
# # http://r-pkgs.had.co.nz/tests.html
# # skip_on_cran()
#
#
# # Load Data
# data = USeconomic
#
#
# library(caret)
#
# # Random Parallel
Expand All @@ -21,6 +25,11 @@
# seed = 1,
# verbose = 1)
#
# model$summarize_hyperparam_results()
# model$summarize_best_hyperparams()
# model$plot_hyperparam_results()
# model$summarize_build()
#
# # # #testthat::expect_equal(good, TRUE)
#
# })
Expand Down

0 comments on commit ab58730

Please sign in to comment.