From 2a4f69b3017664346f37fe6750c510cd8e67b4fe Mon Sep 17 00:00:00 2001 From: Sebastian Funk Date: Thu, 21 Sep 2023 17:48:26 +0200 Subject: [PATCH] pmf_max -> pmf_length --- R/create.R | 11 ++++++++--- R/dist.R | 14 ++++---------- R/get.R | 4 ++-- inst/stan/data/delays.stan | 4 ++-- inst/stan/data/simulation_delays.stan | 4 ++-- inst/stan/functions/infections.stan | 6 +++--- inst/stan/functions/pmfs.stan | 8 ++++---- 7 files changed, 25 insertions(+), 26 deletions(-) diff --git a/R/create.R b/R/create.R index 0e5098b74..841ed71b4 100644 --- a/R/create.R +++ b/R/create.R @@ -669,13 +669,18 @@ create_stan_delays <- function(..., weight = 1) { ret$types_id <- array(ret$types_id) ## map delays to identifiers ret$types_groups <- array(c(0, cumsum(unname(type_n[type_n > 0]))) + 1) - ## map pmfs - ret$np_pmf_groups <- array(c(0, cumsum(combined_delays$np_pmf_length)) + 1) + ## get non zero length delay pmf lengths + ret$np_pmf_groups <- array( + c(0, cumsum( + combined_delays$np_pmf_length[combined_delays$np_pmf_length > 0]) + ) + 1 + ) + ## calculate total np pmf length + ret$np_pmf_length <- sum(combined_delays$np_pmf_length) ## assign prior weights ret$weight <- array(rep(weight, ret$n_p)) ## remove auxiliary variables ret$fixed <- NULL - ret$np_pmf_length <- NULL names(ret) <- paste("delay", names(ret), sep = "_") ret <- c(ret, ids) diff --git a/R/dist.R b/R/dist.R index 85a6c0550..bb4c9ed69 100644 --- a/R/dist.R +++ b/R/dist.R @@ -969,9 +969,7 @@ dist_spec <- function(mean, sd = 0, mean_sd = 0, sd_sd = 0, n = 0, n_p = 0, n_np = 0, - np_pmf_max = 0, np_pmf = numeric(0), - np_pmf_length = integer(0), fixed = integer(0) )) } else { ## parametric fixed @@ -1007,9 +1005,7 @@ dist_spec <- function(mean, sd = 0, mean_sd = 0, sd_sd = 0, n = 1, n_p = 0, n_np = 1, - np_pmf_max = length(pmf), np_pmf = pmf, - np_pmf_length = length(pmf), fixed = 1L )) } @@ -1025,14 +1021,13 @@ dist_spec <- function(mean, sd = 0, mean_sd = 0, sd_sd = 0, n = 1, n_p = 1, n_np = 0, - np_pmf_max = 0, np_pmf = numeric(0), - np_pmf_length = integer(0), fixed = 0L ) } ret <- purrr::map(ret, array) - sum_args <- grep("(^n$|^n_|_max$)", names(ret)) + sum_args <- grep("(^n$|^n_$)", names(ret)) + ret$np_pmf_length <- length(ret$np_pmf) ret[sum_args] <- purrr::map(ret[sum_args], sum) attr(ret, "class") <- c("list", "dist_spec") return(ret) @@ -1088,9 +1083,8 @@ dist_spec_plus <- function(e1, e2, tolerance = 0.001) { delays$fixed <- c(1, rep(0, delays$n_p)) delays$n_np <- 1 delays$n <- delays$n_p + 1 - delays$np_pmf_max <- length(delays$np_pmf) - delays$np_pmf_length <- length(delays$np_pmf) } + delays$np_pmf_length <- length(delays$np_pmf) return(delays) } @@ -1150,7 +1144,7 @@ dist_spec_plus <- function(e1, e2, tolerance = 0.001) { delays <- purrr::transpose(delays) ## convert back to arrays delays <- purrr::map(delays, function(x) array(unlist(x))) - sum_args <- grep("(^n$|^n_|_max$)", names(delays)) + sum_args <- grep("^n($|_)", names(delays)) delays[sum_args] <- purrr::map(delays[sum_args], sum) attr(delays, "class") <- c("list", "dist_spec") return(delays) diff --git a/R/get.R b/R/get.R index 544199907..2c7769910 100644 --- a/R/get.R +++ b/R/get.R @@ -291,8 +291,8 @@ get_seeding_time <- function(delays, generation_time) { ## make sure we have at least (length of total gt pmf - 1) seeding time seeding_time <- max( seeding_time, - sum(generation_time$max) + sum(generation_time$np_pmf_max) - - length(generation_time$max) - length(generation_time$np_pmf_max) + sum(generation_time$max) + sum(generation_time$np_pmf_length) - + length(generation_time$max) - length(generation_time$np_pmf_length) ) return(seeding_time) } diff --git a/inst/stan/data/delays.stan b/inst/stan/data/delays.stan index 5fa8207b8..0c624ffbd 100644 --- a/inst/stan/data/delays.stan +++ b/inst/stan/data/delays.stan @@ -7,8 +7,8 @@ array[delay_n_p] real delay_sd_sd; // prior sd of sd of delay distribution array[delay_n_p] int delay_max; // maximum delay distribution array[delay_n_p] int delay_dist; // 0 = lognormal; 1 = gamma - int delay_np_pmf_max; // number of nonparametric pmf elements - vector[delay_np_pmf_max] delay_np_pmf; // ragged array of fixed PMFs + int delay_np_pmf_length; // number of nonparametric pmf elements + vector[delay_np_pmf_length] delay_np_pmf; // ragged array of fixed PMFs array[delay_n_np + 1] int delay_np_pmf_groups; // links to ragged array array[delay_n_p] int delay_weight; diff --git a/inst/stan/data/simulation_delays.stan b/inst/stan/data/simulation_delays.stan index 0f3fe7d9e..0ceeedcaa 100644 --- a/inst/stan/data/simulation_delays.stan +++ b/inst/stan/data/simulation_delays.stan @@ -5,8 +5,8 @@ array[n, delay_n_p] real delay_sd; // prior sd of sd of delay distribution array[delay_n_p] int delay_max; // maximum delay distribution array[delay_n_p] int delay_dist; // 0 = lognormal; 1 = gamma - int delay_np_pmf_max; // number of nonparametric pmf elements - vector[delay_np_pmf_max] delay_np_pmf; // ragged array of fixed PMFs + int delay_np_pmf_length; // number of nonparametric pmf elements + vector[delay_np_pmf_length] delay_np_pmf; // ragged array of fixed PMFs array[delay_n_np + 1] int delay_np_pmf_groups; // links to ragged array array[delay_n_p] int delay_weight; diff --git a/inst/stan/functions/infections.stan b/inst/stan/functions/infections.stan index db1de31a7..b7790c582 100644 --- a/inst/stan/functions/infections.stan +++ b/inst/stan/functions/infections.stan @@ -2,15 +2,15 @@ // for a single time point real update_infectiousness(vector infections, vector gt_rev_pmf, int seeding_time, int index){ - int gt_max = num_elements(gt_rev_pmf); + int gt_length = num_elements(gt_rev_pmf); // work out where to start the convolution of past infections with the // generation time distribution: (current_time - maximal generation time) if // that is >= 1, otherwise 1 - int inf_start = max(1, (index + seeding_time - gt_max + 1)); + int inf_start = max(1, (index + seeding_time - gt_length + 1)); // work out where to end the convolution: current_time int inf_end = (index + seeding_time); // number of indices of the generation time to sum over (inf_end - inf_start + 1) - int pmf_accessed = min(gt_max, index + seeding_time); + int pmf_accessed = min(gt_length, index + seeding_time); // calculate the elements of the convolution real new_inf = dot_product( infections[inf_start:inf_end], tail(gt_rev_pmf, pmf_accessed) diff --git a/inst/stan/functions/pmfs.stan b/inst/stan/functions/pmfs.stan index adc27d53e..03fbec1f5 100644 --- a/inst/stan/functions/pmfs.stan +++ b/inst/stan/functions/pmfs.stan @@ -36,10 +36,10 @@ vector discretised_pmf(real mu, real sigma, int n, int dist) { // reverse a mf vector reverse_mf(vector pmf) { - int max_pmf = num_elements(pmf); - vector[max_pmf] rev_pmf; - for (d in 1:max_pmf) { - rev_pmf[d] = pmf[max_pmf - d + 1]; + int pmf_length = num_elements(pmf); + vector[pmf_length] rev_pmf; + for (d in 1:pmf_length) { + rev_pmf[d] = pmf[pmf_length - d + 1]; } return rev_pmf; }