Skip to content

Commit

Permalink
Merge pull request #26 from poissonconsulting/dev
Browse files Browse the repository at this point in the history
add predict/plot functionality for ML fit objects
  • Loading branch information
sebdalgarno authored Apr 2, 2024
2 parents 2e6d92b + 5dd6e5b commit 23f1707
Show file tree
Hide file tree
Showing 63 changed files with 839 additions and 70 deletions.
6 changes: 6 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ S3method(bb_plot_month,data.frame)
S3method(bb_plot_month_survival,bboufit_survival)
S3method(bb_plot_month_survival,data.frame)
S3method(bb_plot_year,bboufit)
S3method(bb_plot_year,bboufit_ml)
S3method(bb_plot_year,data.frame)
S3method(bb_plot_year_recruitment,bboufit_recruitment)
S3method(bb_plot_year_recruitment,data.frame)
Expand Down Expand Up @@ -44,6 +45,8 @@ S3method(predict,bboufit_survival)
S3method(print,bboufit)
S3method(print,bboufit_ml)
S3method(rhat,bboufit)
S3method(samples,bboufit)
S3method(samples,bboufit_ml)
S3method(summary,bboufit)
S3method(summary,bboufit_ml)
S3method(tidy,bboufit)
Expand All @@ -69,6 +72,8 @@ export(bb_predict_recruitment)
export(bb_predict_recruitment_trend)
export(bb_predict_survival)
export(bb_predict_survival_trend)
export(bb_priors_recruitment)
export(bb_priors_survival)
export(coef)
export(converged)
export(esr)
Expand All @@ -86,6 +91,7 @@ export(nterms)
export(pars)
export(predict)
export(rhat)
export(samples)
export(tidy)
import(chk)
import(glue)
Expand Down
3 changes: 0 additions & 3 deletions R/extract.R

This file was deleted.

5 changes: 3 additions & 2 deletions R/fit-recruitment.R
Original file line number Diff line number Diff line change
Expand Up @@ -171,13 +171,14 @@ bb_fit_recruitment_ml <- function(

attrs <- list(
nobs = nrow(data$data),
converged = !convergence_fail
converged = !convergence_fail,
year_trend = year_trend
)

.attrs_bboufit_ml(fit) <- attrs

fit$data <- data$data
fit$model_code <- model$getCode()
class(fit) <- c("bboufit_ml")
class(fit) <- c("bboufit_recruitment", "bboufit_ml")
fit
}
5 changes: 3 additions & 2 deletions R/fit-survival.R
Original file line number Diff line number Diff line change
Expand Up @@ -185,13 +185,14 @@ bb_fit_survival_ml <- function(data,

attrs <- list(
nobs = nrow(data$data),
converged = !convergence_fail
converged = !convergence_fail,
year_trend = year_trend
)

.attrs_bboufit_ml(fit) <- attrs

fit$data <- data$data
fit$model_code <- model$getCode()
class(fit) <- c("bboufit_ml")
class(fit) <- c("bboufit_survival", "bboufit_ml")
fit
}
3 changes: 2 additions & 1 deletion R/getters.R
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@

.attrs_bboufit_ml <- function(fit) {
attrs <- attributes(fit)
attrs[c("nobs", "converged")]
attrs[c("nobs", "converged", "year_trend")]
}

`.attrs_bboufit<-` <- function(fit, value) {
Expand All @@ -64,5 +64,6 @@
`.attrs_bboufit_ml<-` <- function(fit, value) {
.converged_bboufit(fit) <- value$converged
.nobs_bboufit(fit) <- value$nobs
.year_trend_bboufit(fit) <- value$year_trend
fit
}
15 changes: 10 additions & 5 deletions R/plot-month.R
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ bb_plot_month.data.frame <- function(x, ...) {
check_data(x, values = list(
Month = c(1L, 12L),
estimate = c(0, Inf),
lower = c(0, Inf),
upper = c(0, Inf)
lower = c(0, Inf, NA),
upper = c(0, Inf, NA)
))

breaks2 <- seq(2, 12, by = 2)
Expand All @@ -27,13 +27,18 @@ bb_plot_month.data.frame <- function(x, ...) {
# this deals with empty data.frame since scale_x_discrete doesn't like it
gp <- ggplot(data = x) +
aes(x = .data$Month, y = .data$estimate, ymin = .data$lower, ymax = .data$upper) +
geom_pointrange() +
xlab("Month")


if(any(is.na(x$lower))){
gp <- gp + ggplot2::geom_point()
} else {
gp <- gp + geom_pointrange()
}

if (length(x$Month)) {
return(gp + scale_x_discrete(breaks = breaks2, labels = month.abb[breaks2], drop = FALSE))
}

gp
}

Expand Down
17 changes: 12 additions & 5 deletions R/plot-year-trend.R
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,28 @@ plot_year_trend <- function(x, ...) {
check_data(x, values = list(
CaribouYear = 1L,
estimate = c(0, Inf),
lower = c(0, Inf),
upper = c(0, Inf)
lower = c(0, Inf, NA),
upper = c(0, Inf, NA)
))

ggplot(data = x) +
gp <- ggplot(data = x) +
aes(
x = as.integer(.data$CaribouYear),
y = .data$estimate
) +
geom_line() +
xlab(" Caribou Year")

if(any(is.na(x$lower))){
return(gp)
}

gp +
geom_ribbon(aes(
ymin = .data$lower,
ymax = .data$upper
), alpha = 0.2) +
xlab(" Caribou Year")
), alpha = 0.2)

}

#' Plot Annual Survival Trend
Expand Down
22 changes: 18 additions & 4 deletions R/plot-year.R
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,24 @@ bb_plot_year.data.frame <- function(x, ...) {
check_data(x, values = list(
CaribouYear = 1L,
estimate = c(0, Inf),
lower = c(0, Inf),
upper = c(0, Inf)
lower = c(0, Inf, NA),
upper = c(0, Inf, NA)
))

ggplot(data = x) +
gp <- ggplot(data = x) +
aes(
x = as.integer(.data$CaribouYear),
y = .data$estimate,
ymin = .data$lower,
ymax = .data$upper
) +
geom_pointrange() +
xlab("Caribou Year")

if(any(is.na(x$lower))){
return(gp + ggplot2::geom_point())
}
gp + geom_pointrange()

}

#' @describeIn bb_plot_year Plot annual estimates for a bboufit object.
Expand All @@ -38,3 +43,12 @@ bb_plot_year.bboufit <- function(x, conf_level = 0.95, estimate = median, ...) {
x <- predict(x, conf_level = conf_level, estimate = estimate)
bb_plot_year(x)
}

#' @describeIn bb_plot_year Plot annual estimates for a bboufit_ml object.
#' @inheritParams params
#' @export
bb_plot_year.bboufit_ml <- function(x, conf_level = 0.95, estimate = median, ...) {
chk_unused(...)
x <- predict(x, conf_level = conf_level, estimate = estimate)
bb_plot_year(x)
}
6 changes: 3 additions & 3 deletions R/predict-growth.R
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
predict_lambda <- function(survival, recruitment){
.chk_fit(survival)
chkor_vld(.vld_fit(survival), .vld_fit_ml(survival))
chk_s3_class(survival, "bboufit_survival")
.chk_fit(recruitment)
chkor_vld(.vld_fit(recruitment), .vld_fit_ml(recruitment))
chk_s3_class(recruitment, "bboufit_recruitment")

pred_sur <- predict_survival(survival, year = TRUE, month = FALSE)
pred_rec <- predict_recruitment(recruitment, year = TRUE)

Expand Down
4 changes: 2 additions & 2 deletions R/predict-trend.R
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ bb_predict_recruitment_trend <- function(recruitment,
conf_level = 0.95,
estimate = median,
sig_fig = 5) {
.chk_fit(recruitment)
chkor_vld(.vld_fit(recruitment), .vld_fit_ml(recruitment))
chk_s3_class(recruitment, "bboufit_recruitment")
.chk_year_trend(recruitment)
chk_range(conf_level)
Expand Down Expand Up @@ -54,7 +54,7 @@ bb_predict_survival_trend <- function(survival,
conf_level = 0.95,
estimate = median,
sig_fig = 5) {
.chk_fit(survival)
chkor_vld(.vld_fit(survival), .vld_fit_ml(survival))
chk_s3_class(survival, "bboufit_survival")
.chk_year_trend(survival)
chk_range(conf_level)
Expand Down
4 changes: 2 additions & 2 deletions R/predict.R
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ bb_predict_recruitment <- function(recruitment,
conf_level = 0.95,
estimate = median,
sig_fig = 3) {
.chk_fit(recruitment)
chkor_vld(.vld_fit(recruitment), .vld_fit_ml(recruitment))
chk_s3_class(recruitment, "bboufit_recruitment")
chk_flag(year)
chk_range(conf_level)
Expand Down Expand Up @@ -143,7 +143,7 @@ bb_predict_survival <- function(survival,
conf_level = 0.95,
estimate = median,
sig_fig = 3) {
.chk_fit(survival)
chkor_vld(.vld_fit(survival), .vld_fit_ml(survival))
chk_s3_class(survival, "bboufit_survival")
chk_flag(year)
chk_flag(month)
Expand Down
88 changes: 88 additions & 0 deletions R/priors.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
priors_survival <- function(){
c(
b0_mu = 4,
b0_sd = 2,
bYear_mu = 0,
bYear_sd = 1,
bAnnual_sd = 1,
sAnnual_rate = 1,
sMonth_rate = 1
)
}

priors_recruitment <- function(){
c(
b0_mu = -1.4,
b0_sd = 0.5,
bYear_mu = 0,
bYear_sd = 1,
bAnnual_sd = 1,
sAnnual_rate = 1,
adult_female_proportion_alpha = 65,
adult_female_proportion_beta = 35
)
}

#' Survival model default priors
#'
#' Prior distribution parameters and default values for survival model parameters.
#'
#' Intercept
#'
#' `b0 ~ Normal(mu = b0_mu, sd = b0_sd)`
#'
#' Year Trend
#'
#' `bYear ~ Normal(mu = bYear_mu, sd = bYear_sd)`
#'
#' Year fixed effect
#'
#' `bAnnual ~ Normal(mu = 0, sd = bAnnual_sd)`
#'
#' Standard deviation of annual random effect
#'
#' `sAnnual ~ Exponential(rate = sAnnual_rate)`
#'
#' Standard deviation of month random effect
#'
#' `sMonth ~ Exponential(rate = sMonth_rate)`
#'
#' @return A named vector.
#' @export
#'
#' @examples bb_priors_survival()
bb_priors_survival <- function() {
priors_survival()
}

#' Recruitment model default priors
#'
#' Prior distribution parameters and default values for recruitment model parameters.
#'
#' Intercept
#'
#' `b0 ~ Normal(mu = b0_mu, sd = b0_sd)`
#'
#' Year Trend
#'
#' `bYear ~ Normal(mu = bYear_mu, sd = bYear_sd)`
#'
#' Year fixed effect
#'
#' `bAnnual ~ Normal(mu = 0, sd = bAnnual_sd)`
#'
#' Standard deviation of annual random effect
#'
#' `sAnnual ~ Exponential(rate = sAnnual_rate)`
#'
#' Adult female proportion
#'
#' `adult_female_proportion ~ Beta(alpha = adult_female_proportion_alpha, beta = adult_female_proportion_beta)`
#'
#' @return A named vector.
#' @export
#'
#' @examples bb_priors_survival()
bb_priors_recruitment <- function() {
priors_recruitment()
}
22 changes: 22 additions & 0 deletions R/samples.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
#' Get MCMC samples
#'
#' Get MCMC samples from Nimble model.
#' @inheritParams params
#' @export
samples <- function(x) {
UseMethod("samples")
}

#' @describeIn samples Get MCMC samples from bboufit object.
#'
#' @export
samples.bboufit <- function(x) {
x$samples
}

#' @describeIn samples Create MCMC samples (1 iteration, 1 chain) from bboufit_ml object.
#'
#' @export
samples.bboufit_ml <- function(x) {
mcmcr::as.mcmcr(estimates(x))
}
Binary file modified R/sysdata.rda
Binary file not shown.
25 changes: 0 additions & 25 deletions R/terms.R
Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,6 @@ params_survival <- function() {
)
}

priors_survival <- function() {
c(
b0_mu = 4,
b0_sd = 2,
bYear_mu = 0,
bYear_sd = 1,
sAnnual_rate = 1,
sMonth_rate = 1,
bAnnual_sd = 1
)
}

inits_survival <- function() {
c(
b0 = 4,
Expand All @@ -40,19 +28,6 @@ params_recruitment <- function() {
)
}

priors_recruitment <- function() {
c(
b0_mu = -1.4,
b0_sd = 0.5,
bYear_mu = 0,
bYear_sd = 1,
sAnnual_rate = 1,
adult_female_proportion_alpha = 65,
adult_female_proportion_beta = 35,
bAnnual_sd = 1
)
}

inits_recruitment <- function() {
c(
b0 = -1.4,
Expand Down
Loading

0 comments on commit 23f1707

Please sign in to comment.