Skip to content

Commit

Permalink
Switching to broadcasting random walk (#747)
Browse files Browse the repository at this point in the history
* setup broadcasting for rw

* make wider changes to make broadcasting approach easier

* update interface and test code

* add unit tests for Rt_opts

* write unit tests for create_rt_data

* Update NEWS.md

* revert setup changes

* update docs

* catch initialisation issue

* Update NEWS.md

Co-authored-by: James Azam <[email protected]>

---------

Co-authored-by: James Azam <[email protected]>
  • Loading branch information
seabbs and jamesmbaazam authored Aug 29, 2024
1 parent 8e0bf2a commit 39cdaff
Show file tree
Hide file tree
Showing 9 changed files with 227 additions and 36 deletions.
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
- The interface for defining delay distributions has been generalised to also cater for continuous distributions
- When defining probability distributions these can now be truncated using the `tolerance` argument
- Ornstein-Uhlenbeck and 5 / 2 Matérn kernels have been added. By @sbfnk in # and reviewed by @.
- Switch to broadcasting from random walks and added unit tests. By @seabbs in #747 and reviewed by @jamesmbaazam.
- Optimised convolution code to take into account the relative length of the vectors being convolved. See #745 by @seabbs and reviewed by @jamesmbaazam.
- Switch to broadcasting the day of the week effect. By @seabbs in #746 and reviewed by @jamesmbaazam.
- A warning is now thrown if nonparametric PMFs passed to delay options have consecutive tail values that are below a certain low threshold as these lead to loss in speed with little gain in accuracy. By @jamesmbaazam in #752 and reviewed by @seabbs.
Expand All @@ -28,6 +29,7 @@

- Updated the documentation of the dots argument of the `stan_sampling_opts()` to add that the dots are passed to `cmdstanr::sample()`. By @jamesmbaazam in #699 and reviewed by @sbfnk.
- `generation_time_opts()` has been shortened to `gt_opts()` to make it easier to specify. Calls to both functions are equivalent. By @jamesmbaazam in #698 and reviewed by @seabbs and @sbfnk .
- Added stan documentation for `update_rt()`. By @seabbs in #747 and reviewed by @jamesmbaazam.

# EpiNow2 1.5.2

Expand Down
27 changes: 21 additions & 6 deletions R/create.R
Original file line number Diff line number Diff line change
Expand Up @@ -261,15 +261,20 @@ create_future_rt <- function(future = c("latest", "project", "estimate"),
#'
#' # using breakpoints
#' create_rt_data(rt_opts(use_breakpoints = TRUE), breakpoints = rep(1, 10))
#'
#' # using random walk
#' create_rt_data(rt_opts(rw = 7), breakpoints = rep(1, 10))
#' }
create_rt_data <- function(rt = rt_opts(), breakpoints = NULL,
delay = 0, horizon = 0) {

# Define if GP is on or off
if (is.null(rt)) {
rt <- rt_opts(
use_rt = FALSE,
future = "project",
gp_on = "R0"
gp_on = "R0",
rw = 0
)
}
# define future Rt arguments
Expand All @@ -279,24 +284,34 @@ create_rt_data <- function(rt = rt_opts(), breakpoints = NULL,
)
# apply random walk
if (rt$rw != 0) {
breakpoints <- as.integer(seq_along(breakpoints) %% rt$rw == 0)
if (is.null(breakpoints)) {
stop("breakpoints must be supplied when using random walk")
}

breakpoints <- seq_along(breakpoints)
breakpoints <- floor(breakpoints / rt$rw)
if (!(rt$future == "project")) {
max_bps <- length(breakpoints) - horizon + future_rt$from
if (max_bps < length(breakpoints)) {
breakpoints[(max_bps + 1):length(breakpoints)] <- 0
breakpoints[(max_bps + 1):length(breakpoints)] <- breakpoints[max_bps]
}
}
}else {
breakpoints <- cumsum(breakpoints)
}
# check breakpoints
if (is.null(breakpoints) || sum(breakpoints) == 0) {

if (sum(breakpoints) == 0) {
rt$use_breakpoints <- FALSE
}
# add a shift for 0 effect in breakpoints
breakpoints <- breakpoints + 1

# map settings to underlying gp stan requirements
rt_data <- list(
r_mean = rt$prior$mean,
r_sd = rt$prior$sd,
estimate_r = as.numeric(rt$use_rt),
bp_n = ifelse(rt$use_breakpoints, sum(breakpoints, na.rm = TRUE), 0),
bp_n = ifelse(rt$use_breakpoints, max(breakpoints) - 1, 0),
breakpoints = breakpoints,
future_fixed = as.numeric(future_rt$fixed),
fixed_from = future_rt$from,
Expand Down
2 changes: 1 addition & 1 deletion inst/stan/estimate_infections.stan
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ parameters{
array[estimate_r] real initial_infections ; // seed infections
array[estimate_r && seeding_time > 1 ? 1 : 0] real initial_growth; // seed growth rate
array[bp_n > 0 ? 1 : 0] real<lower = 0> bp_sd; // standard deviation of breakpoint effect
array[bp_n] real bp_effects; // Rt breakpoint effects
vector[bp_n] bp_effects; // Rt breakpoint effects
// observation model

vector<lower = delay_params_lower>[delay_params_length] delay_params; // delay parameters
Expand Down
66 changes: 44 additions & 22 deletions inst/stan/functions/rt.stan
Original file line number Diff line number Diff line change
@@ -1,26 +1,34 @@
// update a vector of Rts
/**
* Update a vector of effective reproduction numbers (Rt) based on
* an intercept, breakpoints (i.e. a random walk), and a Gaussian
* process.
*
* @param t Length of the time series
* @param log_R Logarithm of the base reproduction number
* @param noise Vector of Gaussian process noise values
* @param bps Array of breakpoint indices
* @param bp_effects Vector of breakpoint effects
* @param stationary Flag indicating whether the Gaussian process is stationary
* (1) or non-stationary (0)
* @return A vector of length t containing the updated Rt values
*/
vector update_Rt(int t, real log_R, vector noise, array[] int bps,
array[] real bp_effects, int stationary) {
vector bp_effects, int stationary) {
// define control parameters
int bp_n = num_elements(bp_effects);
int bp_c = 0;
int gp_n = num_elements(noise);
// define result vectors
vector[t] bp = rep_vector(0, t);
vector[t] gp = rep_vector(0, t);
vector[t] R;
// initialise breakpoints
// initialise intercept
vector[t] R = rep_vector(log_R, t);
//initialise breakpoints + rw
if (bp_n) {
for (s in 1:t) {
if (bps[s]) {
bp_c += bps[s];
bp[s] = bp_effects[bp_c];
}
}
bp = cumulative_sum(bp);
vector[bp_n + 1] bp0;
bp0[1] = 0;
bp0[2:(bp_n + 1)] = cumulative_sum(bp_effects);
R = R + bp0[bps];
}
//initialise gaussian process
if (gp_n) {
vector[t] gp = rep_vector(0, t);
if (stationary) {
gp[1:gp_n] = noise;
// fix future gp based on last estimated
Expand All @@ -31,18 +39,31 @@ vector update_Rt(int t, real log_R, vector noise, array[] int bps,
gp[2:(gp_n + 1)] = noise;
gp = cumulative_sum(gp);
}
R = R + gp;
}
// Calculate Rt
R = rep_vector(log_R, t) + bp + gp;
R = exp(R);
return(R);

return exp(R);
}
// Rt priors

/**
* Calculate the log-probability of the reproduction number (Rt) priors
*
* @param log_R Logarithm of the base reproduction number
* @param initial_infections Array of initial infection values
* @param initial_growth Array of initial growth rates
* @param bp_effects Vector of breakpoint effects
* @param bp_sd Array of breakpoint standard deviations
* @param bp_n Number of breakpoints
* @param seeding_time Time point at which seeding occurs
* @param r_logmean Log-mean of the prior distribution for the base reproduction number
* @param r_logsd Log-standard deviation of the prior distribution for the base reproduction number
* @param prior_infections Prior mean for initial infections
* @param prior_growth Prior mean for initial growth rates
*/
void rt_lp(vector log_R, array[] real initial_infections, array[] real initial_growth,
array[] real bp_effects, array[] real bp_sd, int bp_n, int seeding_time,
vector bp_effects, array[] real bp_sd, int bp_n, int seeding_time,
real r_logmean, real r_logsd, real prior_infections,
real prior_growth) {
// prior on R
log_R ~ normal(r_logmean, r_logsd);
//breakpoint effects on Rt
if (bp_n > 0) {
Expand All @@ -51,6 +72,7 @@ void rt_lp(vector log_R, array[] real initial_infections, array[] real initial_g
}
// initial infections
initial_infections ~ normal(prior_infections, 0.2);

if (seeding_time > 1) {
initial_growth ~ normal(prior_growth, 0.2);
}
Expand Down
2 changes: 1 addition & 1 deletion man/EpiNow2-package.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions man/create_rt_data.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

88 changes: 88 additions & 0 deletions tests/testthat/test-create_rt_date.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
test_that("create_rt_data returns expected default values", {
result <- create_rt_data()

expect_type(result, "list")
expect_equal(result$r_mean, 1)
expect_equal(result$r_sd, 1)
expect_equal(result$estimate_r, 1)
expect_equal(result$bp_n, 0)
expect_equal(result$breakpoints, numeric(0))
expect_equal(result$future_fixed, 1)
expect_equal(result$fixed_from, 0)
expect_equal(result$pop, 0)
expect_equal(result$stationary, 0)
expect_equal(result$future_time, 0)
})

test_that("create_rt_data handles NULL rt input correctly", {
result <- create_rt_data(rt = NULL)

expect_equal(result$estimate_r, 0)
expect_equal(result$future_fixed, 0)
expect_equal(result$stationary, 1)
})

test_that("create_rt_data handles custom rt_opts correctly", {
custom_rt <- rt_opts(
prior = list(mean = 2, sd = 0.5),
use_rt = FALSE,
rw = 0,
use_breakpoints = FALSE,
future = "project",
gp_on = "R0",
pop = 1000000
)

result <- create_rt_data(rt = custom_rt, horizon = 7)

expect_equal(result$r_mean, 2)
expect_equal(result$r_sd, 0.5)
expect_equal(result$estimate_r, 0)
expect_equal(result$pop, 1000000)
expect_equal(result$stationary, 1)
expect_equal(result$future_time, 7)
})

test_that("create_rt_data handles breakpoints correctly", {
result <- create_rt_data(rt_opts(use_breakpoints = TRUE),
breakpoints = c(1, 0, 1, 0, 1))

expect_equal(result$bp_n, 3)
expect_equal(result$breakpoints, c(2, 2, 3, 3, 4))
})

test_that("create_rt_data handles random walk correctly", {
result <- create_rt_data(rt_opts(rw = 2),
breakpoints = rep(1, 10))

expect_equal(result$bp_n, 5)
expect_equal(result$breakpoints, c(1, 2, 2, 3, 3, 4, 4, 5, 5, 6))
})

test_that("create_rt_data throws error for invalid inputs", {
expect_error(create_rt_data(rt_opts(rw = 2)),
"breakpoints must be supplied when using random walk")
})

test_that("create_rt_data handles future projections correctly", {
result <- create_rt_data(rt_opts(future = "project"), horizon = 7)

expect_equal(result$future_fixed, 0)
expect_equal(result$fixed_from, 0)
expect_equal(result$future_time, 7)
})

test_that("create_rt_data handles zero sum breakpoints", {
result <- create_rt_data(rt_opts(use_breakpoints = TRUE),
breakpoints = rep(0, 5))

expect_equal(result$bp_n, 0)
})

test_that("create_rt_data adjusts breakpoints for horizon", {
result <- create_rt_data(rt_opts(rw = 2, future = "latest"),
breakpoints = rep(1, 10),
horizon = 3)

expect_equal(result$breakpoints, c(1, 2, 2, 3, 3, 4, 4, 4, 4, 4))
})
61 changes: 61 additions & 0 deletions tests/testthat/test-rt_opts.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
test_that("rt_opts returns expected default values", {
result <- rt_opts()

expect_s3_class(result, "rt_opts")
expect_equal(result$prior, list(mean = 1, sd = 1))
expect_true(result$use_rt)
expect_equal(result$rw, 0)
expect_true(result$use_breakpoints)
expect_equal(result$future, "latest")
expect_equal(result$pop, 0)
expect_equal(result$gp_on, "R_t-1")
})

test_that("rt_opts handles custom inputs correctly", {
result <- rt_opts(
prior = list(mean = 2, sd = 0.5),
use_rt = FALSE,
rw = 7,
use_breakpoints = FALSE,
future = "project",
gp_on = "R0",
pop = 1000000
)

expect_equal(result$prior, list(mean = 2, sd = 0.5))
expect_false(result$use_rt)
expect_equal(result$rw, 7)
expect_true(result$use_breakpoints) # Should be TRUE when rw > 0
expect_equal(result$future, "project")
expect_equal(result$pop, 1000000)
expect_equal(result$gp_on, "R0")
})

test_that("rt_opts sets use_breakpoints to TRUE when rw > 0", {
result <- rt_opts(rw = 3, use_breakpoints = FALSE)
expect_true(result$use_breakpoints)
})

test_that("rt_opts throws error for invalid prior", {
expect_error(rt_opts(prior = list(mean = 1)),
"prior must have both a mean and sd specified")
expect_error(rt_opts(prior = list(sd = 1)),
"prior must have both a mean and sd specified")
})

test_that("rt_opts validates gp_on argument", {
expect_error(rt_opts(gp_on = "invalid"), "must be one")
})

test_that("rt_opts returns object of correct class", {
result <- rt_opts()
expect_s3_class(result, "rt_opts")
expect_true("list" %in% class(result))
})

test_that("rt_opts handles edge cases correctly", {
result <- rt_opts(rw = 0.1, pop = -1)
expect_equal(result$rw, 0.1)
expect_equal(result$pop, -1)
expect_true(result$use_breakpoints)
})
12 changes: 6 additions & 6 deletions tests/testthat/test-stan-rt.R
Original file line number Diff line number Diff line change
Expand Up @@ -32,29 +32,29 @@ test_that("update_Rt works when Rt is fixed", {
})
test_that("update_Rt works when Rt is fixed but a breakpoint is present", {
expect_equal(
round(update_Rt(5, log(1.2), numeric(0), c(0, 0, 1, 0, 0), 0.1, 0), 2),
round(update_Rt(5, log(1.2), numeric(0), c(1, 1, 2, 2, 2), 0.1, 0), 2),
c(1.2, 1.2, rep(1.33, 3))
)
expect_equal(
round(update_Rt(5, log(1.2), numeric(0), c(0, 0, 1, 0, 0), 0.1, 1), 2),
round(update_Rt(5, log(1.2), numeric(0), c(1, 1, 2, 2, 2), 0.1, 1), 2),
c(1.2, 1.2, rep(1.33, 3))
)
expect_equal(
round(update_Rt(5, log(1.2), numeric(0), c(0, 1, 1, 0, 0), rep(0.1, 2), 0), 2),
round(update_Rt(5, log(1.2), numeric(0), c(1, 2, 3, 3, 3), rep(0.1, 2), 0), 2),
c(1.2, 1.33, rep(1.47, 3))
)
})
test_that("update_Rt works when Rt is variable and a breakpoint is present", {
expect_equal(
round(update_Rt(5, log(1.2), rep(0, 4), c(0, 0, 1, 0, 0), 0.1, 0), 2),
round(update_Rt(5, log(1.2), rep(0, 4), c(1, 1, 2, 2, 2), 0.1, 0), 2),
c(1.2, 1.2, rep(1.33, 3))
)
expect_equal(
round(update_Rt(5, log(1.2), rep(0, 5), c(0, 0, 1, 0, 0), 0.1, 1), 2),
round(update_Rt(5, log(1.2), rep(0, 5), c(1, 1, 2, 2, 2), 0.1, 1), 2),
c(1.2, 1.2, rep(1.33, 3))
)
expect_equal(
round(update_Rt(5, log(1.2), rep(0.1, 4), c(0, 0, 1, 0, 0), 0.1, 0), 2),
round(update_Rt(5, log(1.2), rep(0.1, 4), c(1, 1, 2, 2, 2), 0.1, 0), 2),
c(1.20, 1.33, 1.62, 1.79, 1.98)
)
})

0 comments on commit 39cdaff

Please sign in to comment.