Skip to content

Commit

Permalink
pmf_max -> pmf_length
Browse files Browse the repository at this point in the history
  • Loading branch information
sbfnk committed Oct 4, 2023
1 parent 419f806 commit 2a4f69b
Show file tree
Hide file tree
Showing 7 changed files with 25 additions and 26 deletions.
11 changes: 8 additions & 3 deletions R/create.R
Original file line number Diff line number Diff line change
Expand Up @@ -669,13 +669,18 @@ create_stan_delays <- function(..., weight = 1) {
ret$types_id <- array(ret$types_id)
## map delays to identifiers
ret$types_groups <- array(c(0, cumsum(unname(type_n[type_n > 0]))) + 1)
## map pmfs
ret$np_pmf_groups <- array(c(0, cumsum(combined_delays$np_pmf_length)) + 1)
## get non zero length delay pmf lengths
ret$np_pmf_groups <- array(
c(0, cumsum(
combined_delays$np_pmf_length[combined_delays$np_pmf_length > 0])
) + 1
)
## calculate total np pmf length
ret$np_pmf_length <- sum(combined_delays$np_pmf_length)
## assign prior weights
ret$weight <- array(rep(weight, ret$n_p))
## remove auxiliary variables
ret$fixed <- NULL
ret$np_pmf_length <- NULL

names(ret) <- paste("delay", names(ret), sep = "_")
ret <- c(ret, ids)
Expand Down
14 changes: 4 additions & 10 deletions R/dist.R
Original file line number Diff line number Diff line change
Expand Up @@ -969,9 +969,7 @@ dist_spec <- function(mean, sd = 0, mean_sd = 0, sd_sd = 0,
n = 0,
n_p = 0,
n_np = 0,
np_pmf_max = 0,
np_pmf = numeric(0),
np_pmf_length = integer(0),
fixed = integer(0)
))
} else { ## parametric fixed
Expand Down Expand Up @@ -1007,9 +1005,7 @@ dist_spec <- function(mean, sd = 0, mean_sd = 0, sd_sd = 0,
n = 1,
n_p = 0,
n_np = 1,
np_pmf_max = length(pmf),
np_pmf = pmf,
np_pmf_length = length(pmf),
fixed = 1L
))
}
Expand All @@ -1025,14 +1021,13 @@ dist_spec <- function(mean, sd = 0, mean_sd = 0, sd_sd = 0,
n = 1,
n_p = 1,
n_np = 0,
np_pmf_max = 0,
np_pmf = numeric(0),
np_pmf_length = integer(0),
fixed = 0L
)
}
ret <- purrr::map(ret, array)
sum_args <- grep("(^n$|^n_|_max$)", names(ret))
sum_args <- grep("(^n$|^n_$)", names(ret))
ret$np_pmf_length <- length(ret$np_pmf)
ret[sum_args] <- purrr::map(ret[sum_args], sum)
attr(ret, "class") <- c("list", "dist_spec")
return(ret)
Expand Down Expand Up @@ -1088,9 +1083,8 @@ dist_spec_plus <- function(e1, e2, tolerance = 0.001) {
delays$fixed <- c(1, rep(0, delays$n_p))
delays$n_np <- 1
delays$n <- delays$n_p + 1
delays$np_pmf_max <- length(delays$np_pmf)
delays$np_pmf_length <- length(delays$np_pmf)
}
delays$np_pmf_length <- length(delays$np_pmf)
return(delays)
}

Expand Down Expand Up @@ -1150,7 +1144,7 @@ dist_spec_plus <- function(e1, e2, tolerance = 0.001) {
delays <- purrr::transpose(delays)
## convert back to arrays
delays <- purrr::map(delays, function(x) array(unlist(x)))
sum_args <- grep("(^n$|^n_|_max$)", names(delays))
sum_args <- grep("^n($|_)", names(delays))
delays[sum_args] <- purrr::map(delays[sum_args], sum)
attr(delays, "class") <- c("list", "dist_spec")
return(delays)
Expand Down
4 changes: 2 additions & 2 deletions R/get.R
Original file line number Diff line number Diff line change
Expand Up @@ -291,8 +291,8 @@ get_seeding_time <- function(delays, generation_time) {
## make sure we have at least (length of total gt pmf - 1) seeding time
seeding_time <- max(
seeding_time,
sum(generation_time$max) + sum(generation_time$np_pmf_max) -
length(generation_time$max) - length(generation_time$np_pmf_max)
sum(generation_time$max) + sum(generation_time$np_pmf_length) -
length(generation_time$max) - length(generation_time$np_pmf_length)
)
return(seeding_time)
}
4 changes: 2 additions & 2 deletions inst/stan/data/delays.stan
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
array[delay_n_p] real<lower = 0> delay_sd_sd; // prior sd of sd of delay distribution
array[delay_n_p] int<lower = 1> delay_max; // maximum delay distribution
array[delay_n_p] int<lower = 0> delay_dist; // 0 = lognormal; 1 = gamma
int<lower = 0> delay_np_pmf_max; // number of nonparametric pmf elements
vector<lower = 0, upper = 1>[delay_np_pmf_max] delay_np_pmf; // ragged array of fixed PMFs
int<lower = 0> delay_np_pmf_length; // number of nonparametric pmf elements
vector<lower = 0, upper = 1>[delay_np_pmf_length] delay_np_pmf; // ragged array of fixed PMFs
array[delay_n_np + 1] int<lower = 1> delay_np_pmf_groups; // links to ragged array
array[delay_n_p] int<lower = 0> delay_weight;

Expand Down
4 changes: 2 additions & 2 deletions inst/stan/data/simulation_delays.stan
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
array[n, delay_n_p] real<lower = 0> delay_sd; // prior sd of sd of delay distribution
array[delay_n_p] int<lower = 1> delay_max; // maximum delay distribution
array[delay_n_p] int<lower = 0> delay_dist; // 0 = lognormal; 1 = gamma
int<lower = 0> delay_np_pmf_max; // number of nonparametric pmf elements
vector<lower = 0, upper = 1>[delay_np_pmf_max] delay_np_pmf; // ragged array of fixed PMFs
int<lower = 0> delay_np_pmf_length; // number of nonparametric pmf elements
vector<lower = 0, upper = 1>[delay_np_pmf_length] delay_np_pmf; // ragged array of fixed PMFs
array[delay_n_np + 1] int<lower = 1> delay_np_pmf_groups; // links to ragged array
array[delay_n_p] int delay_weight;

Expand Down
6 changes: 3 additions & 3 deletions inst/stan/functions/infections.stan
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,15 @@
// for a single time point
real update_infectiousness(vector infections, vector gt_rev_pmf,
int seeding_time, int index){
int gt_max = num_elements(gt_rev_pmf);
int gt_length = num_elements(gt_rev_pmf);
// work out where to start the convolution of past infections with the
// generation time distribution: (current_time - maximal generation time) if
// that is >= 1, otherwise 1
int inf_start = max(1, (index + seeding_time - gt_max + 1));
int inf_start = max(1, (index + seeding_time - gt_length + 1));
// work out where to end the convolution: current_time
int inf_end = (index + seeding_time);
// number of indices of the generation time to sum over (inf_end - inf_start + 1)
int pmf_accessed = min(gt_max, index + seeding_time);
int pmf_accessed = min(gt_length, index + seeding_time);
// calculate the elements of the convolution
real new_inf = dot_product(
infections[inf_start:inf_end], tail(gt_rev_pmf, pmf_accessed)
Expand Down
8 changes: 4 additions & 4 deletions inst/stan/functions/pmfs.stan
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,10 @@ vector discretised_pmf(real mu, real sigma, int n, int dist) {

// reverse a mf
vector reverse_mf(vector pmf) {
int max_pmf = num_elements(pmf);
vector[max_pmf] rev_pmf;
for (d in 1:max_pmf) {
rev_pmf[d] = pmf[max_pmf - d + 1];
int pmf_length = num_elements(pmf);
vector[pmf_length] rev_pmf;
for (d in 1:pmf_length) {
rev_pmf[d] = pmf[pmf_length - d + 1];
}
return rev_pmf;
}
Expand Down

0 comments on commit 2a4f69b

Please sign in to comment.