diff --git a/DESCRIPTION b/DESCRIPTION index c1df779..818e4e9 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,10 +1,10 @@ Package: tswgewrapped Title: Helpful wrappers for 'tswge', 'vars' and 'nnfor' time series packages -Version: 1.8.10.2 +Version: 1.8.10.3 Authors@R: c( person("David", "Josephs", email = "josephsd@smu.edu", role = c("aut", "cre")), person("Nikhil", "Gupta", email = "guptan@smu.edu", role = c("aut"))) -Description: This package provides several helpful wrappers for the already useful 'tswge', 'vars' and 'nnfor' package. In the future, this package intends to move away from the tswge backend, to be faster, with more readable source code. +Description: This package provides several helpful wrappers for the already useful 'tswge', 'vars' and 'nnfor' package. License: AGPL-3 Encoding: UTF-8 LazyData: true diff --git a/NAMESPACE b/NAMESPACE index bf68e05..0e25eba 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -4,6 +4,7 @@ export("%>%") export(ModelBuildMultivariateVAR) export(ModelBuildNNforCaret) export(ModelCompareMultivariateVAR) +export(ModelCompareNNforCaret) export(ModelCompareUnivariate) export(MultivariateEDA) export(aic5) diff --git a/R/ModelCompareBase.R b/R/ModelCompareBase.R index ce70dcf..4c8dcaa 100644 --- a/R/ModelCompareBase.R +++ b/R/ModelCompareBase.R @@ -12,7 +12,7 @@ ModelCompareBase = R6::R6Class( #' @description #' Initialize an object to compare several Univatiate Time Series Models #' @param data The dataframe containing the time series realizations (data should not contain time index) - #' @param mdl_list A names list of all models (see format below) + #' @param mdl_list A named list of all models (see format below) #' @param n.ahead The number of observations used to calculate ASE or forecast ahead #' @param batch_size If any of the models used sliding ase method, #' then this number indicates the batch size to use @@ -161,17 +161,29 @@ ModelCompareBase = R6::R6Class( res = private$get_sliding_ase_results(name = name, step_n.ahead = step_n.ahead) ## Inplace - private$models[[name]][['ASEs']] = res$ASEs - private$models[[name]][['time_test_start']] = res$time_test_start - private$models[[name]][['time_test_end']] = res$time_test_end - private$models[[name]][['batch_num']] = res$batch_num - private$models[[name]][['f']] = res$f - private$models[[name]][['ll']] = res$ll - private$models[[name]][['ul']] = res$ul - private$models[[name]][['time.forecasts']] = res$time.forecasts + # private$models[[name]][['ASEs']] = res$ASEs + # private$models[[name]][['time_test_start']] = res$time_test_start + # private$models[[name]][['time_test_end']] = res$time_test_end + # private$models[[name]][['batch_num']] = res$batch_num + # + # private$models[[name]][['f']] = res$f + # private$models[[name]][['ll']] = res$ll + # private$models[[name]][['ul']] = res$ul + # private$models[[name]][['time.forecasts']] = res$time.forecasts + + private$models[[name]]$ASEs = res$ASEs + private$models[[name]]$time_test_start = res$time_test_start + private$models[[name]]$time_test_end = res$time_test_end + private$models[[name]]$batch_num = res$batch_num + + private$models[[name]]$f = res$f + private$models[[name]]$ll = res$ll + private$models[[name]]$ul = res$ul + private$models[[name]]$time.forecasts = res$time.forecasts - private$models[[name]][['metric_has_been_computed']] = TRUE + # private$models[[name]][['metric_has_been_computed']] = TRUE + private$models[[name]]$metric_has_been_computed = TRUE } else{ @@ -224,72 +236,75 @@ ModelCompareBase = R6::R6Class( #' @param only_sliding If TRUE, this will only plot the batch forecasts #' for the models that used window ASE calculations plot_batch_forecasts = function(only_sliding = TRUE){ - - results.forecasts = self$get_tabular_metrics(ases = FALSE) - - model_subset = c("Realization") - if (only_sliding){ - for (name in names(private$get_models())){ - if (private$models[[name]][['sliding_ase']] == TRUE){ + if (only_sliding == TRUE & private$any_sliding_ase() == FALSE){ + message("None of your models are using a sliding ASE calculation, hence nothing will be plotted") + } + else{ + results.forecasts = self$get_tabular_metrics(ases = FALSE) + + model_subset = c("Realization") + if (only_sliding){ + for (name in names(private$get_models())){ + if (private$models[[name]][['sliding_ase']] == TRUE){ + model_subset = c(model_subset, name) + } + } + } + else{ + # Add all models + for (name in names(private$get_models())){ model_subset = c(model_subset, name) } } - } - else{ - # Add all models - for (name in names(private$get_models())){ - model_subset = c(model_subset, name) + + results.forecasts = results.forecasts %>% + dplyr::filter(Model %in% model_subset) + + # https://stackoverflow.com/questions/9968975/make-the-background-of-a-graph-different-colours-in-different-regions + + # Get Batch Boundaries + results.ases = self$get_tabular_metrics(ases = TRUE) + if (private$any_sliding_ase()){ + for (name in names(private$get_models())){ + if (private$models[[name]][['sliding_ase']] == TRUE){ + results.batches = results.ases %>% + dplyr::filter(Model == name) + break() + } + } } - } - - results.forecasts = results.forecasts %>% - dplyr::filter(Model %in% model_subset) - - # https://stackoverflow.com/questions/9968975/make-the-background-of-a-graph-different-colours-in-different-regions - - # Get Batch Boundaries - results.ases = self$get_tabular_metrics(ases = TRUE) - if (private$any_sliding_ase()){ - for (name in names(private$get_models())){ - if (private$models[[name]][['sliding_ase']] == TRUE){ + else{ + # No model has sliding ASE, so just pick the 1st one + for (name in names(private$get_models())){ results.batches = results.ases %>% - dplyr::filter(Model == name) + dplyr::filter(Model == name) break() } } + + rects = data.frame(xstart = results.batches[['Time_Test_Start']], + xend = results.batches[['Time_Test_End']], + Batch = rep(1, length(results.batches[['Batch']]))) + + + p = ggplot2::ggplot() + + ggplot2::geom_rect(data = rects, ggplot2::aes(xmin = xstart, xmax = xend, ymin = -Inf, ymax = Inf, fill = Batch), alpha = 0.1, show.legend = FALSE) + + ggplot2::geom_line(results.forecasts %>% dplyr::filter(Model == 'Realization'), mapping = ggplot2::aes(x = Time, y = f, color = Model), size = 1) + + ggplot2::geom_line(results.forecasts %>% dplyr::filter(Model != 'Realization'), mapping = ggplot2::aes(x = Time, y = f, color = Model), size = 0.75) + + ggplot2::ylab("Forecasts") + + print(p) + + p = ggplot2::ggplot() + + ggplot2::geom_rect(data = rects, ggplot2::aes(xmin = xstart, xmax = xend, ymin = -Inf, ymax = Inf, fill = Batch), alpha = 0.1, show.legend = FALSE) + + ggplot2::geom_line(results.forecasts %>% dplyr::filter(Model == 'Realization'), mapping = ggplot2::aes(x=Time, y=ll, color = Model), size = 1) + + ggplot2::geom_line(results.forecasts %>% dplyr::filter(Model == 'Realization'), mapping = ggplot2::aes(x=Time, y=ll, color = Model), size = 1) + + ggplot2::geom_line(results.forecasts %>% dplyr::filter(Model != 'Realization'), mapping = ggplot2::aes(x=Time, y=ll, color = Model), size = 0.75) + + ggplot2::geom_line(results.forecasts %>% dplyr::filter(Model != 'Realization'), mapping = ggplot2::aes(x=Time, y=ul, color = Model), size = 0.75) + + ggplot2::ylab("Upper and Lower Forecast Limits (95%)") + + print(p) } - else{ - # No model has sliding ASE, so just pick the 1st one - for (name in names(private$get_models())){ - results.batches = results.ases %>% - dplyr::filter(Model == name) - break() - } - } - - rects = data.frame(xstart = results.batches[['Time_Test_Start']], - xend = results.batches[['Time_Test_End']], - Batch = rep(1, length(results.batches[['Batch']]))) - - - p = ggplot2::ggplot() + - ggplot2::geom_rect(data = rects, ggplot2::aes(xmin = xstart, xmax = xend, ymin = -Inf, ymax = Inf, fill = Batch), alpha = 0.1, show.legend = FALSE) + - ggplot2::geom_line(results.forecasts %>% dplyr::filter(Model == 'Realization'), mapping = ggplot2::aes(x = Time, y = f, color = Model), size = 1) + - ggplot2::geom_line(results.forecasts %>% dplyr::filter(Model != 'Realization'), mapping = ggplot2::aes(x = Time, y = f, color = Model), size = 0.75) + - ggplot2::ylab("Forecasts") - - print(p) - - - p = ggplot2::ggplot() + - ggplot2::geom_rect(data = rects, ggplot2::aes(xmin = xstart, xmax = xend, ymin = -Inf, ymax = Inf, fill = Batch), alpha = 0.1, show.legend = FALSE) + - ggplot2::geom_line(results.forecasts %>% dplyr::filter(Model == 'Realization'), mapping = ggplot2::aes(x=Time, y=ll, color = Model), size = 1) + - ggplot2::geom_line(results.forecasts %>% dplyr::filter(Model == 'Realization'), mapping = ggplot2::aes(x=Time, y=ll, color = Model), size = 1) + - ggplot2::geom_line(results.forecasts %>% dplyr::filter(Model != 'Realization'), mapping = ggplot2::aes(x=Time, y=ll, color = Model), size = 0.75) + - ggplot2::geom_line(results.forecasts %>% dplyr::filter(Model != 'Realization'), mapping = ggplot2::aes(x=Time, y=ul, color = Model), size = 0.75) + - ggplot2::ylab("Upper and Lower Forecast Limits (95%)") - - print(p) }, @@ -297,63 +312,65 @@ ModelCompareBase = R6::R6Class( #' @param only_sliding If TRUE, this will only plot the ASEs for #' the models that used window ASE calculations plot_batch_ases = function(only_sliding = TRUE){ - - requireNamespace("patchwork") - - model_subset = c() - - if (only_sliding){ - for (name in names(private$get_models())){ - if (private$models[[name]][['sliding_ase']] == TRUE){ - model_subset = c(model_subset, name) - } - } + if (only_sliding == TRUE & private$any_sliding_ase() == FALSE){ + message("None of your models are using a sliding ASE calculation, hence nothing will be plotted") } else{ - # Add all models - for (name in names(private$get_models())){ - model_subset = c(model_subset, name) - } - } - - ASEs = self$get_tabular_metrics(ases = TRUE) %>% - dplyr::filter(Model %in% model_subset) %>% - tidyr::gather("Index", "Time", -Model, -ASE, -Batch) - - all_time = NA - - for (name in model_subset){ - if (all(is.na(all_time))){ - all_time = data.frame(Time = seq(1, private$get_len_x()), - Model = rep(name, private$get_len_x()), - ASE = 0) + requireNamespace("patchwork") + + model_subset = c() + + if (only_sliding){ + for (name in names(private$get_models())){ + if (private$models[[name]][['sliding_ase']] == TRUE){ + model_subset = c(model_subset, name) + } + } } else{ - all_time = rbind(all_time, data.frame(Time = seq(1, private$get_len_x()), - Model = rep(name, private$get_len_x()), - ASE = 0)) + # Add all models + for (name in names(private$get_models())){ + model_subset = c(model_subset, name) + } + } + + ASEs = self$get_tabular_metrics(ases = TRUE) %>% + dplyr::filter(Model %in% model_subset) %>% + tidyr::gather("Index", "Time", -Model, -ASE, -Batch) + + all_time = NA + + for (name in model_subset){ + if (all(is.na(all_time))){ + all_time = data.frame(Time = seq(1, private$get_len_x()), + Model = rep(name, private$get_len_x()), + ASE = 0) + } + else{ + all_time = rbind(all_time, data.frame(Time = seq(1, private$get_len_x()), + Model = rep(name, private$get_len_x()), + ASE = 0)) + } + + all_time = all_time %>% + dplyr::mutate_if(is.factor, as.character) } - all_time = all_time %>% - dplyr::mutate_if(is.factor, as.character) + results = dplyr::left_join(all_time, ASEs, by = c("Time", "Model")) %>% + dplyr::mutate(ASE = ASE.x + ASE.y) %>% + dplyr::group_by(Model) %>% + tidyr::fill(.data$ASE, .direction = "down") + + data = data.frame(Time = seq(1, private$get_len_x()), Data = self$get_data_var_interest()) + + g1 = ggplot2::ggplot() + + ggplot2::geom_line(data, mapping = ggplot2::aes(x = Time, y = Data), size = 1) + + g2 = ggplot2::ggplot() + + ggplot2::geom_line(results, mapping = ggplot2::aes(x = Time, y = ASE, color = Model), size = 1) + + print(g1/g2) } - - results = dplyr::left_join(all_time, ASEs, by = c("Time", "Model")) %>% - dplyr::mutate(ASE = ASE.x + ASE.y) %>% - dplyr::group_by(Model) %>% - tidyr::fill(.data$ASE, .direction = "down") - - data = data.frame(Time = seq(1, private$get_len_x()), Data = self$get_data_var_interest()) - - g1 = ggplot2::ggplot() + - ggplot2::geom_line(data, mapping = ggplot2::aes(x = Time, y = Data), size = 1) - - g2 = ggplot2::ggplot() + - ggplot2::geom_line(results, mapping = ggplot2::aes(x = Time, y = ASE, color = Model), size = 1) - - print(g1/g2) - - }, #' @description Plots the histogram of the ASE values for the models @@ -560,8 +577,8 @@ ModelCompareBase = R6::R6Class( }, evaluate_models = function(){ - private$build_models(verbose = private$get_verbose()) - private$evaluate_xIC() + # private$build_models(verbose = private$get_verbose()) + # private$evaluate_xIC() self$compute_metrics(step_n.ahead = private$get_step_n.ahead()) print(self$summarize_build()) }, diff --git a/R/ModelCompareMultivariateVAR.R b/R/ModelCompareMultivariateVAR.R index 309f54b..22c9288 100644 --- a/R/ModelCompareMultivariateVAR.R +++ b/R/ModelCompareMultivariateVAR.R @@ -16,7 +16,7 @@ ModelCompareMultivariateVAR = R6::R6Class( #' Initialize an object to compare several Univatiate Time Series Models #' @param data The dataframe containing the time series realizations (data should not contain time index) #' @param var_interest The output variable of interest (dependent variable) - #' @param mdl_list A names list of all models (see format below) + #' @param mdl_list A named list of all models (see format below) #' @param n.ahead The number of observations used to calculate ASE or forecast ahead #' @param batch_size If any of the models used sliding ase method, #' then this number indicates the batch size to use @@ -77,9 +77,9 @@ ModelCompareMultivariateVAR = R6::R6Class( #' 'Final_K': The adjusted K value to take into account the smaller batch size (only when using sliding_ase) summarize_build = function(){ results = dplyr::tribble(~Model, ~Trend, ~Season, ~SlidingASE, ~Init_K, ~Final_K) - + for (name in names(private$get_models())){ - results = results %>% + results = results %>% dplyr::add_row(Model = name, Trend = private$models[[name]][['trend_type']], Season = ifelse(is.null(private$models[[name]][['season']]), 0, private$models[[name]][['season']]), @@ -87,9 +87,9 @@ ModelCompareMultivariateVAR = R6::R6Class( Init_K = private$models[[name]][['k_initial']], Final_K = private$models[[name]][['k_final']] ) - + } - + return(results) } @@ -203,14 +203,6 @@ ModelCompareMultivariateVAR = R6::R6Class( return(results) }, - build_models = function(verbose = 0){ - - }, - - evaluate_xIC = function(){ - - }, - validate_k = function(k, batch_size, season, col_names){ # https://stats.stackexchange.com/questions/234975/how-many-endogenous-variables-in-a-var-model-with-120-observations ## num_vars (in code) = K in the equation in link diff --git a/R/ModelCompareNNforCaret.R b/R/ModelCompareNNforCaret.R new file mode 100644 index 0000000..30452d0 --- /dev/null +++ b/R/ModelCompareNNforCaret.R @@ -0,0 +1,226 @@ +#' @title R6 class ModelCompareMultivariate +#' +#' @export +ModelCompareNNforCaret = R6::R6Class( + classname = "ModelCompareNNforCaret", + inherit = ModelCompareBase, + cloneable = TRUE, + lock_objects=F, + lock_class=F, + + #### Public Methods ---- + public=list( + #### Constructor ---- + + #' @description + #' Initialize an object to compare several Univatiate Time Series Models + #' @param data The dataframe containing the time series realizations (data should not contain time index) + #' @param var_interest The output variable of interest (dependent variable) + #' @param mdl_list A single caret model (which may contain results of grid or random search) + #' @param verbose How much to print during the model building and other processes (Default = 0) + #' @return A new `ModelCompareNNforCaret` object. + initialize = function(data = NA, var_interest = NA, mdl_list, verbose = 0) + { + private$set_caret_model(caret_model = mdl_list) + n.ahead = private$compute_n.ahead() + batch_size = private$compute_batch_size() + private$set_var_interest(var_interest = var_interest) + super$initialize(data = data, mdl_list = mdl_list, + n.ahead = n.ahead, batch_size = batch_size, step_n.ahead = TRUE, + verbose = verbose) + + }, + + #### Getters and Setters ---- + + #' @description Returns the dependent variable name + #' @return The dependent variable name + get_var_interest = function(){return(private$var_interest)}, + + #' @description Returns the dependent variable data only + #' @return The dependent variable data only + get_data_var_interest = function(){return(self$get_data()[, self$get_var_interest()])}, + + #### General Public Methods ---- + + #' @description Not applicable for the nnfor::mlp models, since we are passing already build models + summarize_build = function(){ + + } + + ), + + + #### Private Methods ---- + private = list( + var_interest = NA, + caret_model = NA, + + set_var_interest = function(var_interest){private$var_interest = var_interest}, + + get_data_subset = function(col_names){ + return(self$get_data() %>% dplyr::select(col_names)) + }, + + set_caret_model = function(caret_model){private$caret_model = caret_model}, + + get_caret_model = function(){ + return(private$caret_model) + }, + + get_len_x = function(){ + return(nrow(self$get_data())) + }, + + clean_model_input = function(mdl_list, batch_size){ + ## mdl_list is actually the caret model here (which technically is also a list :)) + + results_with_id = mdl_list[['results']] %>% + private$add_model_id() + + sliding_ase = ifelse(nrow(mdl_list[['pred']] %>% dplyr::select(Resample) %>% unique()) > 1, TRUE, FALSE) + + rv_mdl_list = list() + + for (i in seq_len(nrow(mdl_list[['results']]))){ + subset = results_with_id %>% dplyr::slice(i) + + name = subset %>% purrr::pluck("ID") + + rv_mdl_list[[name]][['reps']] = subset %>% purrr::pluck("reps") + rv_mdl_list[[name]][['hd']] = subset %>% purrr::pluck("hd") + rv_mdl_list[[name]][['allow.det.season']] = subset %>% purrr::pluck("allow.det.season") + rv_mdl_list[[name]][['sliding_ase']] = sliding_ase + rv_mdl_list[[name]][['metric_has_been_computed']] = FALSE + + } + + return(rv_mdl_list) + }, + + get_sliding_ase_results = function(name, step_n.ahead){ + + ase_data = private$clean_resample_info(data = private$get_caret_model()[['resample']], subset = name) %>% + dplyr::mutate(time_test_start = Resample + 1) %>% + dplyr::mutate(time_test_end = Resample + self$get_n.ahead()) %>% + dplyr::mutate(batch_num = (Resample + self$get_n.ahead() - self$get_batch_size())/self$get_n.ahead() + 1) + + pred_data = private$clean_resample_info(data = private$get_caret_model()[['pred']], subset = name) + + res = list() + res$ASEs = ase_data %>% purrr::pluck("ASE") + res$time_test_start = ase_data %>% purrr::pluck("time_test_start") + res$time_test_end = ase_data %>% purrr::pluck("time_test_end") + res$batch_num = ase_data %>% purrr::pluck("batch_num") + res$f = pred_data %>% purrr::pluck("pred") + res$ll = pred_data %>% purrr::pluck("pred") + res$ul = pred_data %>% purrr::pluck("pred") + res$time.forecasts = pred_data %>% purrr::pluck("rowIndex") + + return (res) + }, + + compute_simple_forecasts = function(lastn){ + ## TODO: Needed for NNFOR + ## But add an argument xreg + ## Used by plot_simple_forecasts in the base class + + message("This function is not supported for nnfor::mlp at this time.") + + results = dplyr::tribble(~Model, ~Time, ~f, ~ll, ~ul) + # + # if (lastn == FALSE){ + # data_start = 1 + # data_end = private$get_len_x() + # train_data = self$get_data()[data_start:data_end, ] + # + # } + # else{ + # data_start = 1 + # data_end = private$get_len_x() - self$get_n.ahead() + # train_data = self$get_data()[data_start:data_end, ] + # } + # + # from = data_end + 1 + # to = data_end + self$get_n.ahead() + # + # # Define Train Data + # + # for (name in names(private$get_models())){ + # + # var_interest = self$get_var_interest() + # k = private$get_models()[[name]][['k_final']] + # trend_type = private$get_models()[[name]][['trend_type']] + # + # # Fit model for the batch + # varfit = vars::VAR(train_data, p=k, type=trend_type) + # + # # Forecast for the batch + # forecasts = stats::predict(varfit, n.ahead=self$get_n.ahead()) + # forecasts = forecasts$fcst[[var_interest]] ## Get the forecasts only for the dependent variable + # + # results = results %>% + # dplyr::add_row(Model = name, + # Time = (from:to), + # f = forecasts[, 'fcst'], + # ll = forecasts[, 'lower'], + # ul = forecasts[, 'upper']) + # + # } + # + return(results) + }, + + add_model_id = function(dataframe){ + ## For a dataframe containing the columns reps, hd, and allow.det.season, + ## this will add the model ID to each row + return( + dataframe %>% + assertr::verify(assertr::has_all_names("reps", "hd", "allow.det.season")) %>% + dplyr::mutate(ID = paste0("reps", reps, "_hd", hd, "_sdet", allow.det.season)) + ) + + }, + + compute_batch_size = function(){ + data = private$get_data_to_compute_batch_info() + # Batch Size Definition = Training + Test + batch_size = data %>% purrr::pluck("Resample") %>% min() + private$compute_n.ahead() + return(batch_size) + }, + + compute_n.ahead = function(){ + data = private$get_data_to_compute_batch_info() + n.ahead = nrow(private$get_caret_model()[['trainingData']]) - data %>% purrr::pluck("Resample") %>% max() + return(n.ahead) + }, + + get_data_to_compute_batch_info = function(){ + data = private$get_caret_model()[['resample']] %>% + private$add_model_id() + + one_name = data %>% dplyr::slice(1) %>% purrr::pluck("ID") + + data = private$clean_resample_info(data = private$get_caret_model()[['resample']], subset = one_name) + + return(data) + }, + + clean_resample_info = function(data, subset = NA){ + # data is a dataframe with columns 'reps', 'hd', 'allow.det.season' and 'Resample' + # subset = ID/name of the model (single) to get; NA will get all + data = data %>% + assertr::verify(assertr::has_all_names("reps", "hd", "allow.det.season", "Resample")) %>% + private$add_model_id() %>% + dplyr::mutate(subset = ifelse(is.na(subset), ID, subset)) %>% + dplyr::filter(ID == subset) %>% + dplyr::mutate(Resample = as.numeric(gsub("Training", "", Resample))) %>% + dplyr::arrange(Resample) + + return(data) + } + + ) + +) + diff --git a/R/ModelCompareUnivariate.R b/R/ModelCompareUnivariate.R index ffb729e..3331dc4 100644 --- a/R/ModelCompareUnivariate.R +++ b/R/ModelCompareUnivariate.R @@ -72,7 +72,7 @@ ModelCompareUnivariate = R6::R6Class( #' @description #' Initialize an object to compare several Univatiate Time Series Models #' @param data A Univariate Time Series Realization - #' @param mdl_list A names list of all models (see format below) + #' @param mdl_list A named list of all models (see format below) #' @param n.ahead The number of observations used to calculate ASE or forecast ahead #' @param batch_size If any of the models used sliding ase method, #' then this number indicates the batch size to use @@ -282,17 +282,9 @@ ModelCompareUnivariate = R6::R6Class( set_var_interest = function(var_interest = NA){ ## Do nothing. There is only 1 variable of interest. - }, - - build_models = function(verbose = 0){ - ## Do nothing. We are expecting the build model (all parameters known) to be passed - }, - - evaluate_xIC = function(){ - ## Do nothing. Since we are expecting the build model to be passed, the AIC and BIC would be external to this object. } - + ) diff --git a/Readme.md b/Readme.md index 7b6f26b..cf70db9 100644 --- a/Readme.md +++ b/Readme.md @@ -174,10 +174,26 @@ Check out the vignette 'ModelBuildMultivariateVAR' # Multivariate Time Series Model (VAR) Comparison Check out the vignette 'ModelCompareMultivariateVAR' -* Supports comparig the performance of multiple multivariate VAR models +* Supports comparing the performance of multiple multivariate VAR models * Suppport for simple forecasts and plotting * Support for Batch ASE calculations and plotting * Statistical Comparison of models (when using batch ASE method) * Histogram of model comparison (ASE values) * Tabular metrics for manual anaysis (if needed) +# Time Series with nnfor::mlp (Neural Network) Model Building +Check out the vignette 'ModelCompareNNforCaret' +* Builds the model with the caret framework. +* Suppport for predefined or random grid search +* Supports parallel processing using multiple cores to speed up the grid search +* Support for sliding ASE while building models + +# Time Series with nnfor::mlp (Neural Network) Model Comparison +Check out the vignette 'ModelCompareNNforCaret' +* Supports comparing the performance of multiple nnfor::mlp() submodels built by caret +* Does not suppport simple forecasts and plotting yet (planned for the future) +* Support for Batch ASE calculations and plotting +* Statistical Comparison of models (when using batch ASE method) +* Histogram of model comparison (ASE values) +* Tabular metrics for manual anaysis (if needed) + diff --git a/build/tswgewrapped_1.8.10.3.tar.gz b/build/tswgewrapped_1.8.10.3.tar.gz new file mode 100644 index 0000000..bdd17da Binary files /dev/null and b/build/tswgewrapped_1.8.10.3.tar.gz differ diff --git a/inst/extdata/caret_model_batch_ase.rds b/inst/extdata/caret_model_batch_ase.rds new file mode 100644 index 0000000..3ae4a5f Binary files /dev/null and b/inst/extdata/caret_model_batch_ase.rds differ diff --git a/inst/extdata/caret_nnfor_ases.csv b/inst/extdata/caret_nnfor_ases.csv new file mode 100644 index 0000000..add7bc4 --- /dev/null +++ b/inst/extdata/caret_nnfor_ases.csv @@ -0,0 +1,7 @@ +"Model","ASE","Time_Test_Start","Time_Test_End","Batch" +"reps15_hd5_sdetFALSE",4.49210184297525e-05,131,132,1 +"reps15_hd5_sdetFALSE",0.000605052972875598,133,134,2 +"reps15_hd5_sdetFALSE",4.6068496613423e-05,135,136,3 +"reps19_hd2_sdetFALSE",0.000123595898534739,131,132,1 +"reps19_hd2_sdetFALSE",0.000162611783204444,133,134,2 +"reps19_hd2_sdetFALSE",8.09564637458037e-05,135,136,3 diff --git a/inst/extdata/caret_nnfor_forecasts.csv b/inst/extdata/caret_nnfor_forecasts.csv new file mode 100644 index 0000000..47d260a --- /dev/null +++ b/inst/extdata/caret_nnfor_forecasts.csv @@ -0,0 +1,149 @@ +"Model","Time","f","ll","ul" +"reps15_hd5_sdetFALSE",131,8.2278766923447,8.2278766923447,8.2278766923447 +"reps15_hd5_sdetFALSE",132,8.23272068911654,8.23272068911654,8.23272068911654 +"reps15_hd5_sdetFALSE",133,8.22178448771875,8.22178448771875,8.22178448771875 +"reps15_hd5_sdetFALSE",134,8.21731960515902,8.21731960515902,8.21731960515902 +"reps15_hd5_sdetFALSE",135,8.26178320377863,8.26178320377863,8.26178320377863 +"reps15_hd5_sdetFALSE",136,8.26522139965064,8.26522139965064,8.26522139965064 +"reps19_hd2_sdetFALSE",131,8.22965763156111,8.22965763156111,8.22965763156111 +"reps19_hd2_sdetFALSE",132,8.23905707282243,8.23905707282243,8.23905707282243 +"reps19_hd2_sdetFALSE",133,8.22893585193485,8.22893585193485,8.22893585193485 +"reps19_hd2_sdetFALSE",134,8.23246915113718,8.23246915113718,8.23246915113718 +"reps19_hd2_sdetFALSE",135,8.25925374602098,8.25925374602098,8.25925374602098 +"reps19_hd2_sdetFALSE",136,8.26189894366923,8.26189894366923,8.26189894366923 +"Realization",1,7.249072901,7.249072901,7.249072901 +"Realization",2,7.245084291,7.245084291,7.245084291 +"Realization",3,7.257002707,7.257002707,7.257002707 +"Realization",4,7.271564712,7.271564712,7.271564712 +"Realization",5,7.292745534,7.292745534,7.292745534 +"Realization",6,7.303641321,7.303641321,7.303641321 +"Realization",7,7.316880348,7.316880348,7.316880348 +"Realization",8,7.325609985,7.325609985,7.325609985 +"Realization",9,7.323632657,7.323632657,7.323632657 +"Realization",10,7.328174679,7.328174679,7.328174679 +"Realization",11,7.328896866,7.328896866,7.328896866 +"Realization",12,7.33992723,7.33992723,7.33992723 +"Realization",13,7.348136979,7.348136979,7.348136979 +"Realization",14,7.347557399,7.347557399,7.347557399 +"Realization",15,7.353402177,7.353402177,7.353402177 +"Realization",16,7.33778291,7.33778291,7.33778291 +"Realization",17,7.317278807,7.317278807,7.317278807 +"Realization",18,7.322642526,7.322642526,7.322642526 +"Realization",19,7.34601021,7.34601021,7.34601021 +"Realization",20,7.369411667,7.369411667,7.369411667 +"Realization",21,7.381750929,7.381750929,7.381750929 +"Realization",22,7.400620577,7.400620577,7.400620577 +"Realization",23,7.396028498,7.396028498,7.396028498 +"Realization",24,7.404522545,7.404522545,7.404522545 +"Realization",25,7.421536531,7.421536531,7.421536531 +"Realization",26,7.4186609,7.4186609,7.4186609 +"Realization",27,7.419620362,7.419620362,7.419620362 +"Realization",28,7.411012333,7.411012333,7.411012333 +"Realization",29,7.421357046,7.421357046,7.421357046 +"Realization",30,7.43372564,7.43372564,7.43372564 +"Realization",31,7.44792609,7.44792609,7.44792609 +"Realization",32,7.470167154,7.470167154,7.470167154 +"Realization",33,7.483188172,7.483188172,7.483188172 +"Realization",34,7.493539941,7.493539941,7.493539941 +"Realization",35,7.502793366,7.502793366,7.502793366 +"Realization",36,7.501137371,7.501137371,7.501137371 +"Realization",37,7.514581753,7.514581753,7.514581753 +"Realization",38,7.528331767,7.528331767,7.528331767 +"Realization",39,7.545653985,7.545653985,7.545653985 +"Realization",40,7.552814549,7.552814549,7.552814549 +"Realization",41,7.574917763,7.574917763,7.574917763 +"Realization",42,7.583451066,7.583451066,7.583451066 +"Realization",43,7.593474944,7.593474944,7.593474944 +"Realization",44,7.597747488,7.597747488,7.597747488 +"Realization",45,7.619184323,7.619184323,7.619184323 +"Realization",46,7.633563242,7.633563242,7.633563242 +"Realization",47,7.649359235,7.649359235,7.649359235 +"Realization",48,7.672106219,7.672106219,7.672106219 +"Realization",49,7.691702484,7.691702484,7.691702484 +"Realization",50,7.694301724,7.694301724,7.694301724 +"Realization",51,7.704496416,7.704496416,7.704496416 +"Realization",52,7.709398056,7.709398056,7.709398056 +"Realization",53,7.715034394,7.715034394,7.715034394 +"Realization",54,7.72099394,7.72099394,7.72099394 +"Realization",55,7.735302225,7.735302225,7.735302225 +"Realization",56,7.740925237,7.740925237,7.740925237 +"Realization",57,7.752464076,7.752464076,7.752464076 +"Realization",58,7.769336361,7.769336361,7.769336361 +"Realization",59,7.777080182,7.777080182,7.777080182 +"Realization",60,7.776115477,7.776115477,7.776115477 +"Realization",61,7.790075491,7.790075491,7.790075491 +"Realization",62,7.791440171,7.791440171,7.791440171 +"Realization",63,7.796962542,7.796962542,7.796962542 +"Realization",64,7.792968055,7.792968055,7.792968055 +"Realization",65,7.786800945,7.786800945,7.786800945 +"Realization",66,7.785928689,7.785928689,7.785928689 +"Realization",67,7.798030524,7.798030524,7.798030524 +"Realization",68,7.788957548,7.788957548,7.788957548 +"Realization",69,7.815449164,7.815449164,7.815449164 +"Realization",70,7.81536847,7.81536847,7.81536847 +"Realization",71,7.820479659,7.820479659,7.820479659 +"Realization",72,7.820439515,7.820439515,7.820439515 +"Realization",73,7.842121658,7.842121658,7.842121658 +"Realization",74,7.861380331,7.861380331,7.861380331 +"Realization",75,7.871730802,7.871730802,7.871730802 +"Realization",76,7.890320524,7.890320524,7.890320524 +"Realization",77,7.913521017,7.913521017,7.913521017 +"Realization",78,7.916078096,7.916078096,7.916078096 +"Realization",79,7.915092569,7.915092569,7.915092569 +"Realization",80,7.923999937,7.923999937,7.923999937 +"Realization",81,7.918410289,7.918410289,7.918410289 +"Realization",82,7.921245314,7.921245314,7.921245314 +"Realization",83,7.908129773,7.908129773,7.908129773 +"Realization",84,7.899301895,7.899301895,7.899301895 +"Realization",85,7.879556401,7.879556401,7.879556401 +"Realization",86,7.889683927,7.889683927,7.889683927 +"Realization",87,7.906510399,7.906510399,7.906510399 +"Realization",88,7.920337527,7.920337527,7.920337527 +"Realization",89,7.938944891,7.938944891,7.938944891 +"Realization",90,7.943392268,7.943392268,7.943392268 +"Realization",91,7.947537169,7.947537169,7.947537169 +"Realization",92,7.957457396,7.957457396,7.957457396 +"Realization",93,7.971085754,7.971085754,7.971085754 +"Realization",94,7.987082806,7.987082806,7.987082806 +"Realization",95,8.006967388,8.006967388,8.006967388 +"Realization",96,8.004398965,8.004398965,8.004398965 +"Realization",97,8.01317766,8.01317766,8.01317766 +"Realization",98,8.044273314,8.044273314,8.044273314 +"Realization",99,8.052805762,8.052805762,8.052805762 +"Realization",100,8.065139494,8.065139494,8.065139494 +"Realization",101,8.065170924,8.065170924,8.065170924 +"Realization",102,8.064290504,8.064290504,8.064290504 +"Realization",103,8.073215919,8.073215919,8.073215919 +"Realization",104,8.071312256,8.071312256,8.071312256 +"Realization",105,8.081289494,8.081289494,8.081289494 +"Realization",106,8.057377489,8.057377489,8.057377489 +"Realization",107,8.058042456,8.058042456,8.058042456 +"Realization",108,8.070656058,8.070656058,8.070656058 +"Realization",109,8.089819841,8.089819841,8.089819841 +"Realization",110,8.086471812,8.086471812,8.086471812 +"Realization",111,8.090892523,8.090892523,8.090892523 +"Realization",112,8.076919224,8.076919224,8.076919224 +"Realization",113,8.061613042,8.061613042,8.061613042 +"Realization",114,8.064605029,8.064605029,8.064605029 +"Realization",115,8.056585284,8.056585284,8.056585284 +"Realization",116,8.058105763,8.058105763,8.058105763 +"Realization",117,8.066709797,8.066709797,8.066709797 +"Realization",118,8.088960866,8.088960866,8.088960866 +"Realization",119,8.103615263,8.103615263,8.103615263 +"Realization",120,8.121212959,8.121212959,8.121212959 +"Realization",121,8.146622142,8.146622142,8.146622142 +"Realization",122,8.159946656,8.159946656,8.159946656 +"Realization",123,8.166386709,8.166386709,8.166386709 +"Realization",124,8.170525154,8.170525154,8.170525154 +"Realization",125,8.182419511,8.182419511,8.182419511 +"Realization",126,8.188466878,8.188466878,8.188466878 +"Realization",127,8.198584448,8.198584448,8.198584448 +"Realization",128,8.205873949,8.205873949,8.205873949 +"Realization",129,8.221290758,8.221290758,8.221290758 +"Realization",130,8.219218329,8.219218329,8.219218329 +"Realization",131,8.221828349,8.221828349,8.221828349 +"Realization",132,8.225422773,8.225422773,8.225422773 +"Realization",133,8.236605891,8.236605891,8.236605891 +"Realization",134,8.248790734,8.248790734,8.248790734 +"Realization",135,8.259794578,8.259794578,8.259794578 +"Realization",136,8.274611946,8.274611946,8.274611946 diff --git a/man/ModelCompareBase.Rd b/man/ModelCompareBase.Rd index 146b0b4..41eb147 100644 --- a/man/ModelCompareBase.Rd +++ b/man/ModelCompareBase.Rd @@ -52,7 +52,7 @@ Initialize an object to compare several Univatiate Time Series Models \describe{ \item{\code{data}}{The dataframe containing the time series realizations (data should not contain time index)} -\item{\code{mdl_list}}{A names list of all models (see format below)} +\item{\code{mdl_list}}{A named list of all models (see format below)} \item{\code{n.ahead}}{The number of observations used to calculate ASE or forecast ahead} diff --git a/man/ModelCompareMultivariateVAR.Rd b/man/ModelCompareMultivariateVAR.Rd index 8f90d8d..0785eb1 100644 --- a/man/ModelCompareMultivariateVAR.Rd +++ b/man/ModelCompareMultivariateVAR.Rd @@ -76,7 +76,7 @@ Initialize an object to compare several Univatiate Time Series Models \item{\code{var_interest}}{The output variable of interest (dependent variable)} -\item{\code{mdl_list}}{A names list of all models (see format below)} +\item{\code{mdl_list}}{A named list of all models (see format below)} \item{\code{n.ahead}}{The number of observations used to calculate ASE or forecast ahead} diff --git a/man/ModelCompareNNforCaret.Rd b/man/ModelCompareNNforCaret.Rd new file mode 100644 index 0000000..61bc79a --- /dev/null +++ b/man/ModelCompareNNforCaret.Rd @@ -0,0 +1,124 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/ModelCompareNNforCaret.R +\name{ModelCompareNNforCaret} +\alias{ModelCompareNNforCaret} +\title{R6 class ModelCompareMultivariate} +\description{ +R6 class ModelCompareMultivariate + +R6 class ModelCompareMultivariate +} +\section{Super class}{ +\code{\link[tswgewrapped:ModelCompareBase]{tswgewrapped::ModelCompareBase}} -> \code{ModelCompareNNforCaret} +} +\section{Methods}{ +\subsection{Public methods}{ +\itemize{ +\item \href{#method-new}{\code{ModelCompareNNforCaret$new()}} +\item \href{#method-get_var_interest}{\code{ModelCompareNNforCaret$get_var_interest()}} +\item \href{#method-get_data_var_interest}{\code{ModelCompareNNforCaret$get_data_var_interest()}} +\item \href{#method-summarize_build}{\code{ModelCompareNNforCaret$summarize_build()}} +\item \href{#method-clone}{\code{ModelCompareNNforCaret$clone()}} +} +} +\if{html}{ +\out{
Inherited methods} +\itemize{ +\item \out{}\href{../../tswgewrapped/html/ModelCompareBase.html#method-add_models}{\code{tswgewrapped::ModelCompareBase$add_models()}}\out{} +\item \out{}\href{../../tswgewrapped/html/ModelCompareBase.html#method-compute_metrics}{\code{tswgewrapped::ModelCompareBase$compute_metrics()}}\out{} +\item \out{}\href{../../tswgewrapped/html/ModelCompareBase.html#method-get_batch_size}{\code{tswgewrapped::ModelCompareBase$get_batch_size()}}\out{} +\item \out{}\href{../../tswgewrapped/html/ModelCompareBase.html#method-get_data}{\code{tswgewrapped::ModelCompareBase$get_data()}}\out{} +\item \out{}\href{../../tswgewrapped/html/ModelCompareBase.html#method-get_n.ahead}{\code{tswgewrapped::ModelCompareBase$get_n.ahead()}}\out{} +\item \out{}\href{../../tswgewrapped/html/ModelCompareBase.html#method-get_tabular_metrics}{\code{tswgewrapped::ModelCompareBase$get_tabular_metrics()}}\out{} +\item \out{}\href{../../tswgewrapped/html/ModelCompareBase.html#method-keep_models}{\code{tswgewrapped::ModelCompareBase$keep_models()}}\out{} +\item \out{}\href{../../tswgewrapped/html/ModelCompareBase.html#method-plot_batch_ases}{\code{tswgewrapped::ModelCompareBase$plot_batch_ases()}}\out{} +\item \out{}\href{../../tswgewrapped/html/ModelCompareBase.html#method-plot_batch_forecasts}{\code{tswgewrapped::ModelCompareBase$plot_batch_forecasts()}}\out{} +\item \out{}\href{../../tswgewrapped/html/ModelCompareBase.html#method-plot_histogram_ases}{\code{tswgewrapped::ModelCompareBase$plot_histogram_ases()}}\out{} +\item \out{}\href{../../tswgewrapped/html/ModelCompareBase.html#method-plot_simple_forecasts}{\code{tswgewrapped::ModelCompareBase$plot_simple_forecasts()}}\out{} +\item \out{}\href{../../tswgewrapped/html/ModelCompareBase.html#method-remove_models}{\code{tswgewrapped::ModelCompareBase$remove_models()}}\out{} +\item \out{}\href{../../tswgewrapped/html/ModelCompareBase.html#method-set_batch_size}{\code{tswgewrapped::ModelCompareBase$set_batch_size()}}\out{} +\item \out{}\href{../../tswgewrapped/html/ModelCompareBase.html#method-set_verbose}{\code{tswgewrapped::ModelCompareBase$set_verbose()}}\out{} +\item \out{}\href{../../tswgewrapped/html/ModelCompareBase.html#method-statistical_compare}{\code{tswgewrapped::ModelCompareBase$statistical_compare()}}\out{} +} +\out{
} +} +\if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-new}{}}} +\subsection{Method \code{new()}}{ +Initialize an object to compare several Univatiate Time Series Models +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{ModelCompareNNforCaret$new(data = NA, var_interest = NA, mdl_list, verbose = 0)}\if{html}{\out{
}} +} + +\subsection{Arguments}{ +\if{html}{\out{
}} +\describe{ +\item{\code{data}}{The dataframe containing the time series realizations (data should not contain time index)} + +\item{\code{var_interest}}{The output variable of interest (dependent variable)} + +\item{\code{mdl_list}}{A single caret model (which may contain results of grid or random search)} + +\item{\code{verbose}}{How much to print during the model building and other processes (Default = 0)} +} +\if{html}{\out{
}} +} +\subsection{Returns}{ +A new `ModelCompareNNforCaret` object. +} +} +\if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-get_var_interest}{}}} +\subsection{Method \code{get_var_interest()}}{ +Returns the dependent variable name +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{ModelCompareNNforCaret$get_var_interest()}\if{html}{\out{
}} +} + +\subsection{Returns}{ +The dependent variable name +} +} +\if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-get_data_var_interest}{}}} +\subsection{Method \code{get_data_var_interest()}}{ +Returns the dependent variable data only +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{ModelCompareNNforCaret$get_data_var_interest()}\if{html}{\out{
}} +} + +\subsection{Returns}{ +The dependent variable data only +} +} +\if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-summarize_build}{}}} +\subsection{Method \code{summarize_build()}}{ +Not applicable for the nnfor::mlp models, since we are passing already build models +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{ModelCompareNNforCaret$summarize_build()}\if{html}{\out{
}} +} + +} +\if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-clone}{}}} +\subsection{Method \code{clone()}}{ +The objects of this class are cloneable with this method. +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{ModelCompareNNforCaret$clone(deep = FALSE)}\if{html}{\out{
}} +} + +\subsection{Arguments}{ +\if{html}{\out{
}} +\describe{ +\item{\code{deep}}{Whether to make a deep clone.} +} +\if{html}{\out{
}} +} +} +} diff --git a/man/ModelCompareUnivariate.Rd b/man/ModelCompareUnivariate.Rd index 61ccc6e..3efd76e 100644 --- a/man/ModelCompareUnivariate.Rd +++ b/man/ModelCompareUnivariate.Rd @@ -122,7 +122,7 @@ Initialize an object to compare several Univatiate Time Series Models \describe{ \item{\code{data}}{A Univariate Time Series Realization} -\item{\code{mdl_list}}{A names list of all models (see format below)} +\item{\code{mdl_list}}{A named list of all models (see format below)} \item{\code{n.ahead}}{The number of observations used to calculate ASE or forecast ahead} diff --git a/man/tswgewrapped-package.Rd b/man/tswgewrapped-package.Rd index 6385b00..0e03654 100644 --- a/man/tswgewrapped-package.Rd +++ b/man/tswgewrapped-package.Rd @@ -6,7 +6,7 @@ \alias{tswgewrapped-package} \title{tswgewrapped: Helpful wrappers for 'tswge', 'vars' and 'nnfor' time series packages} \description{ -This package provides several helpful wrappers for the already useful 'tswge', 'vars' and 'nnfor' package. In the future, this package intends to move away from the tswge backend, to be faster, with more readable source code. +This package provides several helpful wrappers for the already useful 'tswge', 'vars' and 'nnfor' package. } \author{ \strong{Maintainer}: David Josephs \email{josephsd@smu.edu} diff --git a/tests/testthat.R b/tests/testthat.R index 2c8c0f2..cd08bc6 100644 --- a/tests/testthat.R +++ b/tests/testthat.R @@ -7,13 +7,11 @@ library(tswgewrapped) file = system.file("extdata", "USeconomic.csv", package = "tswgewrapped", mustWork = TRUE) USeconomic = read.csv(file, header = TRUE, stringsAsFactors = FALSE, check.names = FALSE) names(USeconomic) = gsub("[(|)]", "", colnames(USeconomic)) -# colnames(USeconomic) = c("logM1", "logGNP", "rs", "rl") ## Load Datasets data("airlog") data("AirPassengers") data("sunspot.classic") -# data("USeconomic") ## Perform Checks test_check("tswgewrapped") \ No newline at end of file diff --git a/tests/testthat/test-BuildNNforCaret.R b/tests/testthat/test-BuildNNforCaret.R index d8409f7..a179388 100644 --- a/tests/testthat/test-BuildNNforCaret.R +++ b/tests/testthat/test-BuildNNforCaret.R @@ -2,6 +2,7 @@ ## TODO: Run in series and parallel with random grid and random and with and without user defined grid and tuneLength ## TODO: Then compare above to manual test in nnfor ## TODO: Check reproducibility of 2 runs with caret +## TODO: Check univariate test with nnfor ## TODO: For now, we are predicting with the actual future values of xreg. ## In the future, change this to use forecasted values of xreg diff --git a/tests/testthat/test-CompareNNforCaret.R b/tests/testthat/test-CompareNNforCaret.R new file mode 100644 index 0000000..3c289e3 --- /dev/null +++ b/tests/testthat/test-CompareNNforCaret.R @@ -0,0 +1,67 @@ +## TODO: Write all unit tests + +test_that("Random Parallel", { + # http://r-pkgs.had.co.nz/tests.html + # skip_on_cran() + + + # # Load Data + file = system.file("extdata", "USeconomic.csv", package = "tswgewrapped", mustWork = TRUE) + USeconomic = read.csv(file, header = TRUE, stringsAsFactors = FALSE, check.names = FALSE) + names(USeconomic) = gsub("[(|)]", "", colnames(USeconomic)) + data = USeconomic + + # library(caret) + # + # # Random Parallel + # model = ModelBuildNNforCaret$new(data = data, var_interest = "logGNP", m = 2, + # search = 'random', + # grid = NA, tuneLength = 2, + # batch_size = 132, h = 2, + # parallel = TRUE, + # seed = 1, + # verbose = 1) + # + # model$summarize_hyperparam_results() + # model$plot_hyperparam_results() + # + # model$summarize_best_hyperparams() + # model$summarize_build() + # + # caret_model = model$get_final_models(subset = 'a') + # # saveRDS(caret_model, "caret_model_batch_ase.rds") + + # Load already saved model + file = system.file("extdata", "caret_model_batch_ase.rds", package = "tswgewrapped", mustWork = TRUE) + caret_model = readRDS(file) + + mdl_compare = ModelCompareNNforCaret$new(data = data, var_interest = 'logGNP', + mdl_list = caret_model, + verbose = 1) + + ases = mdl_compare$get_tabular_metrics() + # write.csv(ases, file = "caret_nnfor_ases.csv", row.names = FALSE) + # Load target data + ases_file = system.file("extdata", "caret_nnfor_ases.csv", package = "tswgewrapped", mustWork = TRUE) + ases_target = read.csv(ases_file, header = TRUE, stringsAsFactors = FALSE) + good1 = all.equal(as.data.frame(ases), ases_target %>% dplyr::mutate_if(is.numeric, as.double)) + testthat::expect_equal(good1, TRUE) + + forecasts = mdl_compare$get_tabular_metrics(ases = FALSE) + # write.csv(forecasts, file = "caret_nnfor_forecasts.csv", row.names = FALSE) + forecasts_file = system.file("extdata", "caret_nnfor_forecasts.csv", package = "tswgewrapped", mustWork = TRUE) + forecasts_target = read.csv(forecasts_file, header = TRUE, stringsAsFactors = FALSE) + good2 = all.equal(as.data.frame(forecasts), forecasts_target %>% dplyr::mutate_if(is.numeric, as.double)) + testthat::expect_equal(good2, TRUE) + + mdl_compare$plot_histogram_ases() + + result = mdl_compare$statistical_compare() + pval = summary(result)[[1]]$`Pr(>F)`[1] + testthat::expect_equal(round(pval,6), 0.591116) + + mdl_compare$plot_batch_forecasts() + mdl_compare$plot_batch_ases() + mdl_compare$plot_simple_forecasts() + +}) diff --git a/vignettes/ModelCompareMultivariateVAR.Rmd b/vignettes/ModelCompareMultivariateVAR.Rmd index dfdb017..c7a5a58 100644 --- a/vignettes/ModelCompareMultivariateVAR.Rmd +++ b/vignettes/ModelCompareMultivariateVAR.Rmd @@ -116,6 +116,11 @@ mdl_compare$plot_simple_forecasts(lastn = FALSE, limits = TRUE, zoom = 50) mdl_compare$plot_batch_forecasts(only_sliding = FALSE) ``` +## Plot and compare the ASEs per batch +```{r fig.width=10} +mdl_compare$plot_batch_ases() +``` + ## Raw Data and Metrics ```{r} ASEs = mdl_compare$get_tabular_metrics(ases = TRUE) diff --git a/vignettes/ModelCompareNNforCaret.Rmd b/vignettes/ModelCompareNNforCaret.Rmd new file mode 100644 index 0000000..a31bff4 --- /dev/null +++ b/vignettes/ModelCompareNNforCaret.Rmd @@ -0,0 +1,110 @@ +--- +title: "ModelCompareNNforCaret" +output: rmarkdown::html_vignette +vignette: > + %\VignetteIndexEntry{ModelCompareNNforCaret} + %\VignetteEngine{knitr::rmarkdown} + %\VignetteEncoding{UTF-8} +--- + +```{r, include = FALSE} +knitr::opts_chunk$set( + collapse = TRUE, + comment = "#>" +) +``` + +# Setup Libraries +```{r setup} +library(tswgewrapped) +``` + +# Load Data +```{r} +file = system.file("extdata", "USeconomic.csv", package = "tswgewrapped", mustWork = TRUE) +USeconomic = read.csv(file, header = TRUE, stringsAsFactors = FALSE, check.names = FALSE) +names(USeconomic) = gsub("[(|)]", "", colnames(USeconomic)) +data = USeconomic +``` + +# Build caret model + +**Since this process takes some time, I have commented this out for now and saved an already created caret model. However, feel free to uncomment this and run the model build process.** + +```{r} +# library(caret) +# +# # Random Parallel +# model = ModelBuildNNforCaret$new(data = data, var_interest = "logGNP", m = 2, +# search = 'random', +# grid = NA, tuneLength = 2, +# batch_size = 132, h = 2, +# parallel = TRUE, +# seed = 1, +# verbose = 1) +# +# model$summarize_hyperparam_results() +# model$plot_hyperparam_results() +# +# model$summarize_best_hyperparams() +# model$summarize_build() +# +# caret_model = model$get_final_models(subset = 'a') +``` + + +# Load already saved caret model +```{r} +file = system.file("extdata", "caret_model_batch_ase.rds", package = "tswgewrapped", mustWork = TRUE) +caret_model = readRDS(file) +``` + +# Initialize the ModelCompareMultivariateVAR object + +```{r} +mdl_compare = ModelCompareNNforCaret$new(data = data, var_interest = 'logGNP', + mdl_list = caret_model, + verbose = 1) +``` + +# Compare the models + +## Compare Histogram of ASE values +```{r fig.width = 8} +mdl_compare$plot_histogram_ases() +``` + +## Statistically Compare the models +```{r} +mdl_compare$statistical_compare() +``` + +## Simple Forecasts (with various options) + +**This is not currently supported since it needs future values to be passed and we dont have these values yet (unless we forecast them). We will add this functionality in the future.** + +```{r fig.width=8} +mdl_compare$plot_simple_forecasts() +``` + +## Plot and compare the forecasts per batch +```{r fig.width=8} +mdl_compare$plot_batch_forecasts() +``` + +## Plot and compare the ASEs per batch +```{r fig.width=8} +mdl_compare$plot_batch_ases() +``` + +## Raw Data and Metrics +```{r} +ASEs = mdl_compare$get_tabular_metrics(ases = TRUE) +print(ASEs) +``` + +```{r} +forecasts = mdl_compare$get_tabular_metrics(ases = FALSE) +print(forecasts) +``` +