diff --git a/DESCRIPTION b/DESCRIPTION index 195afb7..7d9ecac 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,8 +1,8 @@ Encoding: UTF-8 Type: Package Package: mr.mash.alpha -Version: 0.3.25 -Date: 2024-05-06 +Version: 0.3.27 +Date: 2024-05-29 Title: Multiple Regression with Multivariate Adaptive Shrinkage Description: Provides an implementation of methods for multivariate multiple regression with adaptive shrinkage priors. @@ -18,6 +18,7 @@ License: MIT + file LICENSE Depends: R (>= 3.1.0) Imports: stats, + Matrix, Rcpp (>= 1.0.7), RcppParallel (>= 5.1.5), mvtnorm, @@ -25,8 +26,7 @@ Imports: mashr (>= 0.2.73), ebnm, flashier (>= 1.0.7), - parallel, - Rfast + parallel Suggests: testthat, varbvs, diff --git a/NAMESPACE b/NAMESPACE index 391bc1f..b80beaf 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -16,11 +16,14 @@ export(mr.mash.rss) export(predict.mr.mash) export(predict.mr.mash.rss) export(simulate_mr_mash_data) +importFrom(Matrix,crossprod) +importFrom(Matrix,diag) +importFrom(Matrix,isSymmetric) +importFrom(Matrix,t) importFrom(Rcpp,evalCpp) importFrom(RcppParallel,RcppParallelLibs) importFrom(RcppParallel,defaultNumThreads) importFrom(RcppParallel,setThreadOptions) -importFrom(Rfast,is.symmetric) importFrom(ebnm,ebnm_normal) importFrom(ebnm,ebnm_normal_scale_mixture) importFrom(flashier,flash) diff --git a/R/RcppExports.R b/R/RcppExports.R index bae2e1d..3924c9d 100644 --- a/R/RcppExports.R +++ b/R/RcppExports.R @@ -13,6 +13,10 @@ inner_loop_general_rss_rcpp <- function(n, XtX, XtY, XtRbar, mu1, V, Vinv, w0, S .Call('_mr_mash_alpha_inner_loop_general_rss_rcpp', PACKAGE = 'mr.mash.alpha', n, XtX, XtY, XtRbar, mu1, V, Vinv, w0, S0, precomp_quants_list, standardize, compute_ELBO, update_V, update_order, eps, nthreads) } +inner_loop_general_rss_sparse_rcpp <- function(n, XtX, XtY, XtRbar, mu1, V, Vinv, w0, S0, precomp_quants_list, standardize, compute_ELBO, update_V, update_order, eps, nthreads) { + .Call('_mr_mash_alpha_inner_loop_general_rss_sparse_rcpp', PACKAGE = 'mr.mash.alpha', n, XtX, XtY, XtRbar, mu1, V, Vinv, w0, S0, precomp_quants_list, standardize, compute_ELBO, update_V, update_order, eps, nthreads) +} + scale_rcpp <- function(M, a, b) { .Call('_mr_mash_alpha_scale_rcpp', PACKAGE = 'mr.mash.alpha', M, a, b) } diff --git a/R/mr_mash_rss.R b/R/mr_mash_rss.R index 5600223..ae2384e 100644 --- a/R/mr_mash_rss.R +++ b/R/mr_mash_rss.R @@ -13,7 +13,7 @@ #' @param Z p x r matrix of Z-scores from univariate #' simple linear regression. #' -#' @param R p x p correlation matrix among the variables. +#' @param R p x p dense or sparse correlation matrix among the variables. #' #' @param covY r x r covariance matrix across responses. #' @@ -171,7 +171,7 @@ #' abline(a = 0,b = 1,col = "magenta",lty = "dotted") #' #' @importFrom stats cov -#' @importFrom Rfast is.symmetric +#' @importFrom Matrix isSymmetric t diag crossprod #' @importFrom RcppParallel defaultNumThreads #' @importFrom RcppParallel setThreadOptions #' @@ -242,14 +242,14 @@ mr.mash.rss <- function(Bhat, Shat, Z, R, covY, n, S0, w0=rep(1/(length(S0)), le if(any(is.na(Z))) stop("Z must not contain missing values.") } - if(!is.matrix(V) || !is.symmetric(V)) + if(!is.matrix(V) || !isSymmetric(V)) stop("V must be a symmetric matrix.") if(!missing(covY)){ - if(!is.matrix(covY) || !is.symmetric(covY)) + if(!is.matrix(covY) || !isSymmetric(covY)) stop("covY must be a symmetric matrix.") } - if(!is.matrix(R) || !is.symmetric(R)) - stop("R must be a symmetric matrix.") + if(!(is.matrix(R) || inherits(R,"CsparseMatrix")) || !isSymmetric(R)) + stop("R must be a dense or sparse symmetric matrix.") if(!is.list(S0)) stop("S0 must be a list.") if(!is.vector(w0)) @@ -271,6 +271,9 @@ mr.mash.rss <- function(Bhat, Shat, Z, R, covY, n, S0, w0=rep(1/(length(S0)), le # PRE-PROCESSING STEPS # -------------------- + ###Check if R is sparse + R_is_sparse <- inherits(R,"CsparseMatrix") + ###Compute Z scores if(missing(Z)){ Z <- Bhat/Shat @@ -303,6 +306,10 @@ mr.mash.rss <- function(Bhat, Shat, Z, R, covY, n, S0, w0=rep(1/(length(S0)), le covY <- cov2cor(V) } + if(R_is_sparse){ + XtX <- as(XtX, "symmetricMatrix") + } + YtY <- covY*(n-1) ###Check whether XtX is positive semidefinite @@ -490,7 +497,7 @@ mr.mash.rss <- function(Bhat, Shat, Z, R, covY, n, S0, w0=rep(1/(length(S0)), le standardize=standardize, update_V=update_V, version=version, update_order=update_order, eps=eps, - nthreads=nthreads) + R_is_sparse=R_is_sparse, nthreads=nthreads) mu1_t <- ups$mu1_t S1_t <- ups$S1_t w1_t <- ups$w1_t diff --git a/R/mr_mash_rss_updates.R b/R/mr_mash_rss_updates.R index 6282d33..783d8d7 100644 --- a/R/mr_mash_rss_updates.R +++ b/R/mr_mash_rss_updates.R @@ -77,11 +77,18 @@ inner_loop_general_rss_R <- function(n, XtX, XtY, XtRbar, mu1, V, Vinv, w0, S0, #' inner_loop_general_rss_Rcpp <- function(n, XtX, XtY, XtRbar, mu1, V, Vinv, w0, S0, precomp_quants, standardize, compute_ELBO, update_V, update_order, - eps, nthreads){ + eps, R_is_sparse, nthreads){ - out <- inner_loop_general_rss_rcpp(n, XtX, XtY, XtRbar, mu1, V, Vinv, w0, S0, precomp_quants, - standardize, compute_ELBO, update_V, update_order, - eps, nthreads) + if(R_is_sparse){ + out <- inner_loop_general_rss_sparse_rcpp(n, XtX, XtY, XtRbar, mu1, V, Vinv, w0, S0, precomp_quants, + standardize, compute_ELBO, update_V, update_order, + eps, nthreads) + } else { + out <- inner_loop_general_rss_rcpp(n, XtX, XtY, XtRbar, mu1, V, Vinv, w0, S0, precomp_quants, + standardize, compute_ELBO, update_V, update_order, + eps, nthreads) + } + ###Return output if(compute_ELBO && update_V){ @@ -100,14 +107,14 @@ inner_loop_general_rss_Rcpp <- function(n, XtX, XtY, XtRbar, mu1, V, Vinv, w0, S ###Wrapper of the inner loop with R or Rcpp inner_loop_general_rss <- function(n, XtX, XtY, XtRbar, mu1, V, Vinv, w0, S0, precomp_quants, standardize, compute_ELBO, update_V, version, - update_order, eps, nthreads){ + update_order, eps, R_is_sparse, nthreads){ if(version=="R"){ out <- inner_loop_general_rss_R(n, XtX, XtY, XtRbar, mu1, V, Vinv, w0, S0, precomp_quants, standardize, compute_ELBO, update_V, update_order, eps) } else if(version=="Rcpp"){ update_order <- as.integer(update_order-1) out <- inner_loop_general_rss_Rcpp(n, XtX, XtY, XtRbar, mu1, V, Vinv, w0, simplify2array_custom(S0), precomp_quants, - standardize, compute_ELBO, update_V, update_order, eps, nthreads) + standardize, compute_ELBO, update_V, update_order, eps, R_is_sparse, nthreads) } return(out) @@ -118,18 +125,19 @@ inner_loop_general_rss <- function(n, XtX, XtY, XtRbar, mu1, V, Vinv, w0, S0, pr mr_mash_update_general_rss <- function(n, XtX, XtY, YtY, mu1_t, V, Vinv, ldetV, w0, S0, precomp_quants, compute_ELBO, standardize, update_V, version, update_order, eps, - nthreads){ + R_is_sparse, nthreads){ ##Compute expected residuals XtRbar <- XtY - XtX %*% mu1_t + XtRbar <- as.matrix(XtRbar) ##Update variational parameters, expected residuals, and ELBO components updates <- inner_loop_general_rss(n=n, XtX=XtX, XtY=XtY, XtRbar=XtRbar, mu1=mu1_t, V=V, Vinv=Vinv, w0=w0, S0=S0, precomp_quants=precomp_quants, standardize=standardize, compute_ELBO=compute_ELBO, update_V=update_V, version=version, - update_order=update_order, eps=eps, nthreads=nthreads) + update_order=update_order, eps=eps, R_is_sparse=R_is_sparse, nthreads=nthreads) mu1_t <- updates$mu1 S1_t <- updates$S1 w1_t <- updates$w1 @@ -138,6 +146,7 @@ mr_mash_update_general_rss <- function(n, XtX, XtY, YtY, mu1_t, V, Vinv, ldetV, if(compute_ELBO || update_V){ RbartRbar <- YtY - crossprod(mu1_t, XtY) - crossprod(XtY, mu1_t) + crossprod(mu1_t, XtX)%*%mu1_t + RbartRbar <- as.matrix(RbartRbar) } if(compute_ELBO && update_V){ diff --git a/man/mr.mash.Rd b/man/mr.mash.Rd index bcd2a65..8d6cb59 100644 --- a/man/mr.mash.Rd +++ b/man/mr.mash.Rd @@ -11,7 +11,7 @@ mr.mash( w0 = rep(1/(length(S0)), length(S0)), V = NULL, mu1_init = matrix(0, nrow = ncol(X), ncol = ncol(Y)), - tol = 0.0001, + tol = 1e-04, convergence_criterion = c("mu1", "ELBO"), max_iter = 5000, update_w0 = TRUE, diff --git a/man/mr.mash.rss.Rd b/man/mr.mash.rss.Rd index 98f803e..729d24f 100644 --- a/man/mr.mash.rss.Rd +++ b/man/mr.mash.rss.Rd @@ -16,7 +16,7 @@ mr.mash.rss( w0 = rep(1/(length(S0)), length(S0)), V, mu1_init = NULL, - tol = 0.0001, + tol = 1e-04, convergence_criterion = c("mu1", "ELBO"), max_iter = 5000, update_w0 = TRUE, @@ -47,7 +47,7 @@ from univariate simple linear regression.} \item{Z}{p x r matrix of Z-scores from univariate simple linear regression.} -\item{R}{p x p correlation matrix among the variables.} +\item{R}{p x p dense or sparse correlation matrix among the variables.} \item{covY}{r x r covariance matrix across responses.} diff --git a/src/RcppExports.cpp b/src/RcppExports.cpp index 6b5fae5..5a88114 100644 --- a/src/RcppExports.cpp +++ b/src/RcppExports.cpp @@ -76,6 +76,32 @@ BEGIN_RCPP return rcpp_result_gen; END_RCPP } +// inner_loop_general_rss_sparse_rcpp +List inner_loop_general_rss_sparse_rcpp(unsigned int n, const arma::sp_mat& XtX, const arma::mat& XtY, arma::mat& XtRbar, arma::mat& mu1, const arma::mat& V, const arma::mat& Vinv, const arma::vec& w0, const arma::cube& S0, const List& precomp_quants_list, bool standardize, bool compute_ELBO, bool update_V, const arma::vec& update_order, double eps, unsigned int nthreads); +RcppExport SEXP _mr_mash_alpha_inner_loop_general_rss_sparse_rcpp(SEXP nSEXP, SEXP XtXSEXP, SEXP XtYSEXP, SEXP XtRbarSEXP, SEXP mu1SEXP, SEXP VSEXP, SEXP VinvSEXP, SEXP w0SEXP, SEXP S0SEXP, SEXP precomp_quants_listSEXP, SEXP standardizeSEXP, SEXP compute_ELBOSEXP, SEXP update_VSEXP, SEXP update_orderSEXP, SEXP epsSEXP, SEXP nthreadsSEXP) { +BEGIN_RCPP + Rcpp::RObject rcpp_result_gen; + Rcpp::RNGScope rcpp_rngScope_gen; + Rcpp::traits::input_parameter< unsigned int >::type n(nSEXP); + Rcpp::traits::input_parameter< const arma::sp_mat& >::type XtX(XtXSEXP); + Rcpp::traits::input_parameter< const arma::mat& >::type XtY(XtYSEXP); + Rcpp::traits::input_parameter< arma::mat& >::type XtRbar(XtRbarSEXP); + Rcpp::traits::input_parameter< arma::mat& >::type mu1(mu1SEXP); + Rcpp::traits::input_parameter< const arma::mat& >::type V(VSEXP); + Rcpp::traits::input_parameter< const arma::mat& >::type Vinv(VinvSEXP); + Rcpp::traits::input_parameter< const arma::vec& >::type w0(w0SEXP); + Rcpp::traits::input_parameter< const arma::cube& >::type S0(S0SEXP); + Rcpp::traits::input_parameter< const List& >::type precomp_quants_list(precomp_quants_listSEXP); + Rcpp::traits::input_parameter< bool >::type standardize(standardizeSEXP); + Rcpp::traits::input_parameter< bool >::type compute_ELBO(compute_ELBOSEXP); + Rcpp::traits::input_parameter< bool >::type update_V(update_VSEXP); + Rcpp::traits::input_parameter< const arma::vec& >::type update_order(update_orderSEXP); + Rcpp::traits::input_parameter< double >::type eps(epsSEXP); + Rcpp::traits::input_parameter< unsigned int >::type nthreads(nthreadsSEXP); + rcpp_result_gen = Rcpp::wrap(inner_loop_general_rss_sparse_rcpp(n, XtX, XtY, XtRbar, mu1, V, Vinv, w0, S0, precomp_quants_list, standardize, compute_ELBO, update_V, update_order, eps, nthreads)); + return rcpp_result_gen; +END_RCPP +} // scale_rcpp arma::mat scale_rcpp(const arma::mat& M, const arma::vec& a, const arma::vec& b); RcppExport SEXP _mr_mash_alpha_scale_rcpp(SEXP MSEXP, SEXP aSEXP, SEXP bSEXP) { @@ -160,6 +186,7 @@ static const R_CallMethodDef CallEntries[] = { {"_mr_mash_alpha_inner_loop_general_rcpp", (DL_FUNC) &_mr_mash_alpha_inner_loop_general_rcpp, 14}, {"_mr_mash_alpha_impute_missing_Y_rcpp", (DL_FUNC) &_mr_mash_alpha_impute_missing_Y_rcpp, 5}, {"_mr_mash_alpha_inner_loop_general_rss_rcpp", (DL_FUNC) &_mr_mash_alpha_inner_loop_general_rss_rcpp, 16}, + {"_mr_mash_alpha_inner_loop_general_rss_sparse_rcpp", (DL_FUNC) &_mr_mash_alpha_inner_loop_general_rss_sparse_rcpp, 16}, {"_mr_mash_alpha_scale_rcpp", (DL_FUNC) &_mr_mash_alpha_scale_rcpp, 3}, {"_mr_mash_alpha_scale2_rcpp", (DL_FUNC) &_mr_mash_alpha_scale2_rcpp, 3}, {"_mr_mash_alpha_rescale_post_mean_covar_rcpp", (DL_FUNC) &_mr_mash_alpha_rescale_post_mean_covar_rcpp, 3}, diff --git a/src/mr_mash_updates.cpp b/src/mr_mash_updates.cpp index f359327..f27b143 100644 --- a/src/mr_mash_updates.cpp +++ b/src/mr_mash_updates.cpp @@ -57,6 +57,15 @@ void inner_loop_general_rss (unsigned int n, const mat& XtX, const mat& XtY, mat cube& S1, mat& w1, double& var_part_tr_wERSS, double& neg_KL, mat& var_part_ERSS); +// Inner loop rss sprse +void inner_loop_general_rss_sparse (unsigned int n, const sp_mat& XtX, const mat& XtY, mat& XtRbar, + mat& mu1, const mat& V, const mat& Vinv, const vec& w0, + const cube& S0, const mr_mash_precomputed_quantities& precomp_quants, + bool standardize, bool compute_ELBO, bool update_V, + const vec& update_order, double eps, unsigned int nthreads, + cube& S1, mat& w1, double& var_part_tr_wERSS, + double& neg_KL, mat& var_part_ERSS); + // FUNCTION DEFINITIONS // -------------------- @@ -357,3 +366,117 @@ void inner_loop_general_rss (unsigned int n, const mat& XtX, const mat& XtY, mat XtRbar -= (Xtx * trans(mu1_mix)); } } + + +// Inner loop rss sparse +// +// [[Rcpp::depends(RcppArmadillo)]] +// [[Rcpp::depends(RcppParallel)]] +// [[Rcpp::export]] +List inner_loop_general_rss_sparse_rcpp (unsigned int n, const arma::sp_mat& XtX, const arma::mat& XtY, + arma::mat& XtRbar, arma::mat& mu1, const arma::mat& V, + const arma::mat& Vinv, const arma::vec& w0, + const arma::cube& S0, const List& precomp_quants_list, + bool standardize, bool compute_ELBO, bool update_V, + const arma::vec& update_order, double eps, unsigned int nthreads) { + unsigned int r = mu1.n_cols; + unsigned int p = mu1.n_rows; + unsigned int k = w0.n_elem; + cube S1(r,r,p); + mat w1(p,k); + mat mu1_new = mu1; + double var_part_tr_wERSS; + double neg_KL; + mat var_part_ERSS(r,r); + mr_mash_precomputed_quantities precomp_quants + (as(precomp_quants_list["S"]), + as(precomp_quants_list["S_chol"]), + as(precomp_quants_list["S1"]), + as(precomp_quants_list["SplusS0_chol"]), + as(precomp_quants_list["V_chol"]), + as(precomp_quants_list["d"]), + as(precomp_quants_list["QtimesV_chol"]), + as(precomp_quants_list["xtx"])); + inner_loop_general_rss_sparse(n, XtX, XtY, XtRbar, mu1_new, V, Vinv, w0, S0, precomp_quants, + standardize, compute_ELBO, update_V, update_order, eps, + nthreads, S1, w1, var_part_tr_wERSS, neg_KL, var_part_ERSS); + return List::create(Named("mu1") = mu1_new, + Named("S1") = S1, + Named("w1") = w1, + Named("var_part_tr_wERSS") = var_part_tr_wERSS, + Named("neg_KL") = neg_KL, + Named("var_part_ERSS") = var_part_ERSS); +} + +// Perform the inner loop rss +void inner_loop_general_rss_sparse (unsigned int n, const sp_mat& XtX, const mat& XtY, mat& XtRbar, + mat& mu1, const mat& V, const mat& Vinv, const vec& w0, + const cube& S0, const mr_mash_precomputed_quantities& precomp_quants, + bool standardize, bool compute_ELBO, bool update_V, + const vec& update_order, double eps, unsigned int nthreads, + cube& S1, mat& w1, double& var_part_tr_wERSS, + double& neg_KL, mat& var_part_ERSS) { + unsigned int p = mu1.n_rows; + unsigned int r = mu1.n_cols; + unsigned int k = w0.n_elem; + vec Xtx(p); + vec XtRbar_j(r); + vec mu1_j(r); + vec mu1_mix(r); + mat S1_mix(r,r); + vec w1_mix(k); + double logbf_mix; + double xtx_j; + + // Initialize ELBO parameters + var_part_tr_wERSS = 0; + neg_KL = 0; + + // Initialize V parameters + var_part_ERSS.zeros(r,r); + + // Repeat for each predictor. + for (unsigned int j : update_order) { + + if (standardize) + xtx_j = n-1; + else + xtx_j = precomp_quants.xtx(j); + + Xtx = XtX.col(j); + mu1_j = trans(mu1.row(j)); + + // Disregard the jth predictor in the expected residuals. + XtRbar += (Xtx * trans(mu1_j)); + XtRbar_j = trans(XtRbar.row(j)); + + // Update the posterior quantities for the jth + // predictor. + if (standardize) + logbf_mix = bayes_mvr_mix_standardized_X_rss(n, XtRbar_j, w0, S0, precomp_quants.S, + precomp_quants.S1, + precomp_quants.SplusS0_chol, + precomp_quants.S_chol, eps, nthreads, + mu1_mix, S1_mix, w1_mix); + else + logbf_mix = bayes_mvr_mix_centered_X_rss(XtRbar_j, V, w0, S0, xtx_j, Vinv, + precomp_quants.V_chol, precomp_quants.d, + precomp_quants.QtimesV_chol, eps, nthreads, + mu1_mix, S1_mix, w1_mix); + + mu1.row(j) = trans(mu1_mix); + S1.slice(j) = S1_mix; + w1.row(j) = trans(w1_mix); + + // Compute ELBO parameters + if (compute_ELBO) + compute_ELBO_rss_terms(var_part_tr_wERSS, neg_KL, XtRbar_j, logbf_mix, mu1_mix, S1_mix, xtx_j, Vinv); + + // Compute V parameters + if (update_V) + compute_var_part_ERSS(var_part_ERSS, S1_mix, xtx_j); + + // Update the expected residuals. + XtRbar -= (Xtx * trans(mu1_mix)); + } +}