Skip to content

Commit

Permalink
Merge pull request #5 from morgantelab/master
Browse files Browse the repository at this point in the history
Support sparse LD matrix.
  • Loading branch information
pcarbo authored Jun 3, 2024
2 parents 9aab396 + 4cabe70 commit 9e2ec99
Show file tree
Hide file tree
Showing 9 changed files with 196 additions and 23 deletions.
8 changes: 4 additions & 4 deletions DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
Encoding: UTF-8
Type: Package
Package: mr.mash.alpha
Version: 0.3.25
Date: 2024-05-06
Version: 0.3.27
Date: 2024-05-29
Title: Multiple Regression with Multivariate Adaptive Shrinkage
Description: Provides an implementation of methods for multivariate
multiple regression with adaptive shrinkage priors.
Expand All @@ -18,15 +18,15 @@ License: MIT + file LICENSE
Depends: R (>= 3.1.0)
Imports:
stats,
Matrix,
Rcpp (>= 1.0.7),
RcppParallel (>= 5.1.5),
mvtnorm,
matrixStats,
mashr (>= 0.2.73),
ebnm,
flashier (>= 1.0.7),
parallel,
Rfast
parallel
Suggests:
testthat,
varbvs,
Expand Down
5 changes: 4 additions & 1 deletion NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,14 @@ export(mr.mash.rss)
export(predict.mr.mash)
export(predict.mr.mash.rss)
export(simulate_mr_mash_data)
importFrom(Matrix,crossprod)
importFrom(Matrix,diag)
importFrom(Matrix,isSymmetric)
importFrom(Matrix,t)
importFrom(Rcpp,evalCpp)
importFrom(RcppParallel,RcppParallelLibs)
importFrom(RcppParallel,defaultNumThreads)
importFrom(RcppParallel,setThreadOptions)
importFrom(Rfast,is.symmetric)
importFrom(ebnm,ebnm_normal)
importFrom(ebnm,ebnm_normal_scale_mixture)
importFrom(flashier,flash)
Expand Down
4 changes: 4 additions & 0 deletions R/RcppExports.R
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@ inner_loop_general_rss_rcpp <- function(n, XtX, XtY, XtRbar, mu1, V, Vinv, w0, S
.Call('_mr_mash_alpha_inner_loop_general_rss_rcpp', PACKAGE = 'mr.mash.alpha', n, XtX, XtY, XtRbar, mu1, V, Vinv, w0, S0, precomp_quants_list, standardize, compute_ELBO, update_V, update_order, eps, nthreads)
}

inner_loop_general_rss_sparse_rcpp <- function(n, XtX, XtY, XtRbar, mu1, V, Vinv, w0, S0, precomp_quants_list, standardize, compute_ELBO, update_V, update_order, eps, nthreads) {
.Call('_mr_mash_alpha_inner_loop_general_rss_sparse_rcpp', PACKAGE = 'mr.mash.alpha', n, XtX, XtY, XtRbar, mu1, V, Vinv, w0, S0, precomp_quants_list, standardize, compute_ELBO, update_V, update_order, eps, nthreads)
}

scale_rcpp <- function(M, a, b) {
.Call('_mr_mash_alpha_scale_rcpp', PACKAGE = 'mr.mash.alpha', M, a, b)
}
Expand Down
21 changes: 14 additions & 7 deletions R/mr_mash_rss.R
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
#' @param Z p x r matrix of Z-scores from univariate
#' simple linear regression.
#'
#' @param R p x p correlation matrix among the variables.
#' @param R p x p dense or sparse correlation matrix among the variables.
#'
#' @param covY r x r covariance matrix across responses.
#'
Expand Down Expand Up @@ -171,7 +171,7 @@
#' abline(a = 0,b = 1,col = "magenta",lty = "dotted")
#'
#' @importFrom stats cov
#' @importFrom Rfast is.symmetric
#' @importFrom Matrix isSymmetric t diag crossprod
#' @importFrom RcppParallel defaultNumThreads
#' @importFrom RcppParallel setThreadOptions
#'
Expand Down Expand Up @@ -242,14 +242,14 @@ mr.mash.rss <- function(Bhat, Shat, Z, R, covY, n, S0, w0=rep(1/(length(S0)), le
if(any(is.na(Z)))
stop("Z must not contain missing values.")
}
if(!is.matrix(V) || !is.symmetric(V))
if(!is.matrix(V) || !isSymmetric(V))
stop("V must be a symmetric matrix.")
if(!missing(covY)){
if(!is.matrix(covY) || !is.symmetric(covY))
if(!is.matrix(covY) || !isSymmetric(covY))
stop("covY must be a symmetric matrix.")
}
if(!is.matrix(R) || !is.symmetric(R))
stop("R must be a symmetric matrix.")
if(!(is.matrix(R) || inherits(R,"CsparseMatrix")) || !isSymmetric(R))
stop("R must be a dense or sparse symmetric matrix.")
if(!is.list(S0))
stop("S0 must be a list.")
if(!is.vector(w0))
Expand All @@ -271,6 +271,9 @@ mr.mash.rss <- function(Bhat, Shat, Z, R, covY, n, S0, w0=rep(1/(length(S0)), le
# PRE-PROCESSING STEPS
# --------------------

###Check if R is sparse
R_is_sparse <- inherits(R,"CsparseMatrix")

###Compute Z scores
if(missing(Z)){
Z <- Bhat/Shat
Expand Down Expand Up @@ -303,6 +306,10 @@ mr.mash.rss <- function(Bhat, Shat, Z, R, covY, n, S0, w0=rep(1/(length(S0)), le
covY <- cov2cor(V)
}

if(R_is_sparse){
XtX <- as(XtX, "symmetricMatrix")
}

YtY <- covY*(n-1)

###Check whether XtX is positive semidefinite
Expand Down Expand Up @@ -490,7 +497,7 @@ mr.mash.rss <- function(Bhat, Shat, Z, R, covY, n, S0, w0=rep(1/(length(S0)), le
standardize=standardize,
update_V=update_V, version=version,
update_order=update_order, eps=eps,
nthreads=nthreads)
R_is_sparse=R_is_sparse, nthreads=nthreads)
mu1_t <- ups$mu1_t
S1_t <- ups$S1_t
w1_t <- ups$w1_t
Expand Down
25 changes: 17 additions & 8 deletions R/mr_mash_rss_updates.R
Original file line number Diff line number Diff line change
Expand Up @@ -77,11 +77,18 @@ inner_loop_general_rss_R <- function(n, XtX, XtY, XtRbar, mu1, V, Vinv, w0, S0,
#'
inner_loop_general_rss_Rcpp <- function(n, XtX, XtY, XtRbar, mu1, V, Vinv, w0, S0, precomp_quants,
standardize, compute_ELBO, update_V, update_order,
eps, nthreads){
eps, R_is_sparse, nthreads){

out <- inner_loop_general_rss_rcpp(n, XtX, XtY, XtRbar, mu1, V, Vinv, w0, S0, precomp_quants,
standardize, compute_ELBO, update_V, update_order,
eps, nthreads)
if(R_is_sparse){
out <- inner_loop_general_rss_sparse_rcpp(n, XtX, XtY, XtRbar, mu1, V, Vinv, w0, S0, precomp_quants,
standardize, compute_ELBO, update_V, update_order,
eps, nthreads)
} else {
out <- inner_loop_general_rss_rcpp(n, XtX, XtY, XtRbar, mu1, V, Vinv, w0, S0, precomp_quants,
standardize, compute_ELBO, update_V, update_order,
eps, nthreads)
}


###Return output
if(compute_ELBO && update_V){
Expand All @@ -100,14 +107,14 @@ inner_loop_general_rss_Rcpp <- function(n, XtX, XtY, XtRbar, mu1, V, Vinv, w0, S
###Wrapper of the inner loop with R or Rcpp
inner_loop_general_rss <- function(n, XtX, XtY, XtRbar, mu1, V, Vinv, w0, S0, precomp_quants,
standardize, compute_ELBO, update_V, version,
update_order, eps, nthreads){
update_order, eps, R_is_sparse, nthreads){
if(version=="R"){
out <- inner_loop_general_rss_R(n, XtX, XtY, XtRbar, mu1, V, Vinv, w0, S0, precomp_quants,
standardize, compute_ELBO, update_V, update_order, eps)
} else if(version=="Rcpp"){
update_order <- as.integer(update_order-1)
out <- inner_loop_general_rss_Rcpp(n, XtX, XtY, XtRbar, mu1, V, Vinv, w0, simplify2array_custom(S0), precomp_quants,
standardize, compute_ELBO, update_V, update_order, eps, nthreads)
standardize, compute_ELBO, update_V, update_order, eps, R_is_sparse, nthreads)
}

return(out)
Expand All @@ -118,18 +125,19 @@ inner_loop_general_rss <- function(n, XtX, XtY, XtRbar, mu1, V, Vinv, w0, S0, pr
mr_mash_update_general_rss <- function(n, XtX, XtY, YtY, mu1_t, V, Vinv, ldetV, w0, S0,
precomp_quants, compute_ELBO, standardize,
update_V, version, update_order, eps,
nthreads){
R_is_sparse, nthreads){



##Compute expected residuals
XtRbar <- XtY - XtX %*% mu1_t
XtRbar <- as.matrix(XtRbar)

##Update variational parameters, expected residuals, and ELBO components
updates <- inner_loop_general_rss(n=n, XtX=XtX, XtY=XtY, XtRbar=XtRbar, mu1=mu1_t, V=V, Vinv=Vinv, w0=w0, S0=S0,
precomp_quants=precomp_quants, standardize=standardize,
compute_ELBO=compute_ELBO, update_V=update_V, version=version,
update_order=update_order, eps=eps, nthreads=nthreads)
update_order=update_order, eps=eps, R_is_sparse=R_is_sparse, nthreads=nthreads)
mu1_t <- updates$mu1
S1_t <- updates$S1
w1_t <- updates$w1
Expand All @@ -138,6 +146,7 @@ mr_mash_update_general_rss <- function(n, XtX, XtY, YtY, mu1_t, V, Vinv, ldetV,

if(compute_ELBO || update_V){
RbartRbar <- YtY - crossprod(mu1_t, XtY) - crossprod(XtY, mu1_t) + crossprod(mu1_t, XtX)%*%mu1_t
RbartRbar <- as.matrix(RbartRbar)
}

if(compute_ELBO && update_V){
Expand Down
2 changes: 1 addition & 1 deletion man/mr.mash.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions man/mr.mash.rss.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

27 changes: 27 additions & 0 deletions src/RcppExports.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,32 @@ BEGIN_RCPP
return rcpp_result_gen;
END_RCPP
}
// inner_loop_general_rss_sparse_rcpp
List inner_loop_general_rss_sparse_rcpp(unsigned int n, const arma::sp_mat& XtX, const arma::mat& XtY, arma::mat& XtRbar, arma::mat& mu1, const arma::mat& V, const arma::mat& Vinv, const arma::vec& w0, const arma::cube& S0, const List& precomp_quants_list, bool standardize, bool compute_ELBO, bool update_V, const arma::vec& update_order, double eps, unsigned int nthreads);
RcppExport SEXP _mr_mash_alpha_inner_loop_general_rss_sparse_rcpp(SEXP nSEXP, SEXP XtXSEXP, SEXP XtYSEXP, SEXP XtRbarSEXP, SEXP mu1SEXP, SEXP VSEXP, SEXP VinvSEXP, SEXP w0SEXP, SEXP S0SEXP, SEXP precomp_quants_listSEXP, SEXP standardizeSEXP, SEXP compute_ELBOSEXP, SEXP update_VSEXP, SEXP update_orderSEXP, SEXP epsSEXP, SEXP nthreadsSEXP) {
BEGIN_RCPP
Rcpp::RObject rcpp_result_gen;
Rcpp::RNGScope rcpp_rngScope_gen;
Rcpp::traits::input_parameter< unsigned int >::type n(nSEXP);
Rcpp::traits::input_parameter< const arma::sp_mat& >::type XtX(XtXSEXP);
Rcpp::traits::input_parameter< const arma::mat& >::type XtY(XtYSEXP);
Rcpp::traits::input_parameter< arma::mat& >::type XtRbar(XtRbarSEXP);
Rcpp::traits::input_parameter< arma::mat& >::type mu1(mu1SEXP);
Rcpp::traits::input_parameter< const arma::mat& >::type V(VSEXP);
Rcpp::traits::input_parameter< const arma::mat& >::type Vinv(VinvSEXP);
Rcpp::traits::input_parameter< const arma::vec& >::type w0(w0SEXP);
Rcpp::traits::input_parameter< const arma::cube& >::type S0(S0SEXP);
Rcpp::traits::input_parameter< const List& >::type precomp_quants_list(precomp_quants_listSEXP);
Rcpp::traits::input_parameter< bool >::type standardize(standardizeSEXP);
Rcpp::traits::input_parameter< bool >::type compute_ELBO(compute_ELBOSEXP);
Rcpp::traits::input_parameter< bool >::type update_V(update_VSEXP);
Rcpp::traits::input_parameter< const arma::vec& >::type update_order(update_orderSEXP);
Rcpp::traits::input_parameter< double >::type eps(epsSEXP);
Rcpp::traits::input_parameter< unsigned int >::type nthreads(nthreadsSEXP);
rcpp_result_gen = Rcpp::wrap(inner_loop_general_rss_sparse_rcpp(n, XtX, XtY, XtRbar, mu1, V, Vinv, w0, S0, precomp_quants_list, standardize, compute_ELBO, update_V, update_order, eps, nthreads));
return rcpp_result_gen;
END_RCPP
}
// scale_rcpp
arma::mat scale_rcpp(const arma::mat& M, const arma::vec& a, const arma::vec& b);
RcppExport SEXP _mr_mash_alpha_scale_rcpp(SEXP MSEXP, SEXP aSEXP, SEXP bSEXP) {
Expand Down Expand Up @@ -160,6 +186,7 @@ static const R_CallMethodDef CallEntries[] = {
{"_mr_mash_alpha_inner_loop_general_rcpp", (DL_FUNC) &_mr_mash_alpha_inner_loop_general_rcpp, 14},
{"_mr_mash_alpha_impute_missing_Y_rcpp", (DL_FUNC) &_mr_mash_alpha_impute_missing_Y_rcpp, 5},
{"_mr_mash_alpha_inner_loop_general_rss_rcpp", (DL_FUNC) &_mr_mash_alpha_inner_loop_general_rss_rcpp, 16},
{"_mr_mash_alpha_inner_loop_general_rss_sparse_rcpp", (DL_FUNC) &_mr_mash_alpha_inner_loop_general_rss_sparse_rcpp, 16},
{"_mr_mash_alpha_scale_rcpp", (DL_FUNC) &_mr_mash_alpha_scale_rcpp, 3},
{"_mr_mash_alpha_scale2_rcpp", (DL_FUNC) &_mr_mash_alpha_scale2_rcpp, 3},
{"_mr_mash_alpha_rescale_post_mean_covar_rcpp", (DL_FUNC) &_mr_mash_alpha_rescale_post_mean_covar_rcpp, 3},
Expand Down
123 changes: 123 additions & 0 deletions src/mr_mash_updates.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,15 @@ void inner_loop_general_rss (unsigned int n, const mat& XtX, const mat& XtY, mat
cube& S1, mat& w1, double& var_part_tr_wERSS,
double& neg_KL, mat& var_part_ERSS);

// Inner loop rss sprse
void inner_loop_general_rss_sparse (unsigned int n, const sp_mat& XtX, const mat& XtY, mat& XtRbar,
mat& mu1, const mat& V, const mat& Vinv, const vec& w0,
const cube& S0, const mr_mash_precomputed_quantities& precomp_quants,
bool standardize, bool compute_ELBO, bool update_V,
const vec& update_order, double eps, unsigned int nthreads,
cube& S1, mat& w1, double& var_part_tr_wERSS,
double& neg_KL, mat& var_part_ERSS);


// FUNCTION DEFINITIONS
// --------------------
Expand Down Expand Up @@ -357,3 +366,117 @@ void inner_loop_general_rss (unsigned int n, const mat& XtX, const mat& XtY, mat
XtRbar -= (Xtx * trans(mu1_mix));
}
}


// Inner loop rss sparse
//
// [[Rcpp::depends(RcppArmadillo)]]
// [[Rcpp::depends(RcppParallel)]]
// [[Rcpp::export]]
List inner_loop_general_rss_sparse_rcpp (unsigned int n, const arma::sp_mat& XtX, const arma::mat& XtY,
arma::mat& XtRbar, arma::mat& mu1, const arma::mat& V,
const arma::mat& Vinv, const arma::vec& w0,
const arma::cube& S0, const List& precomp_quants_list,
bool standardize, bool compute_ELBO, bool update_V,
const arma::vec& update_order, double eps, unsigned int nthreads) {
unsigned int r = mu1.n_cols;
unsigned int p = mu1.n_rows;
unsigned int k = w0.n_elem;
cube S1(r,r,p);
mat w1(p,k);
mat mu1_new = mu1;
double var_part_tr_wERSS;
double neg_KL;
mat var_part_ERSS(r,r);
mr_mash_precomputed_quantities precomp_quants
(as<mat>(precomp_quants_list["S"]),
as<mat>(precomp_quants_list["S_chol"]),
as<cube>(precomp_quants_list["S1"]),
as<cube>(precomp_quants_list["SplusS0_chol"]),
as<mat>(precomp_quants_list["V_chol"]),
as<mat>(precomp_quants_list["d"]),
as<cube>(precomp_quants_list["QtimesV_chol"]),
as<vec>(precomp_quants_list["xtx"]));
inner_loop_general_rss_sparse(n, XtX, XtY, XtRbar, mu1_new, V, Vinv, w0, S0, precomp_quants,
standardize, compute_ELBO, update_V, update_order, eps,
nthreads, S1, w1, var_part_tr_wERSS, neg_KL, var_part_ERSS);
return List::create(Named("mu1") = mu1_new,
Named("S1") = S1,
Named("w1") = w1,
Named("var_part_tr_wERSS") = var_part_tr_wERSS,
Named("neg_KL") = neg_KL,
Named("var_part_ERSS") = var_part_ERSS);
}

// Perform the inner loop rss
void inner_loop_general_rss_sparse (unsigned int n, const sp_mat& XtX, const mat& XtY, mat& XtRbar,
mat& mu1, const mat& V, const mat& Vinv, const vec& w0,
const cube& S0, const mr_mash_precomputed_quantities& precomp_quants,
bool standardize, bool compute_ELBO, bool update_V,
const vec& update_order, double eps, unsigned int nthreads,
cube& S1, mat& w1, double& var_part_tr_wERSS,
double& neg_KL, mat& var_part_ERSS) {
unsigned int p = mu1.n_rows;
unsigned int r = mu1.n_cols;
unsigned int k = w0.n_elem;
vec Xtx(p);
vec XtRbar_j(r);
vec mu1_j(r);
vec mu1_mix(r);
mat S1_mix(r,r);
vec w1_mix(k);
double logbf_mix;
double xtx_j;

// Initialize ELBO parameters
var_part_tr_wERSS = 0;
neg_KL = 0;

// Initialize V parameters
var_part_ERSS.zeros(r,r);

// Repeat for each predictor.
for (unsigned int j : update_order) {

if (standardize)
xtx_j = n-1;
else
xtx_j = precomp_quants.xtx(j);

Xtx = XtX.col(j);
mu1_j = trans(mu1.row(j));

// Disregard the jth predictor in the expected residuals.
XtRbar += (Xtx * trans(mu1_j));
XtRbar_j = trans(XtRbar.row(j));

// Update the posterior quantities for the jth
// predictor.
if (standardize)
logbf_mix = bayes_mvr_mix_standardized_X_rss(n, XtRbar_j, w0, S0, precomp_quants.S,
precomp_quants.S1,
precomp_quants.SplusS0_chol,
precomp_quants.S_chol, eps, nthreads,
mu1_mix, S1_mix, w1_mix);
else
logbf_mix = bayes_mvr_mix_centered_X_rss(XtRbar_j, V, w0, S0, xtx_j, Vinv,
precomp_quants.V_chol, precomp_quants.d,
precomp_quants.QtimesV_chol, eps, nthreads,
mu1_mix, S1_mix, w1_mix);

mu1.row(j) = trans(mu1_mix);
S1.slice(j) = S1_mix;
w1.row(j) = trans(w1_mix);

// Compute ELBO parameters
if (compute_ELBO)
compute_ELBO_rss_terms(var_part_tr_wERSS, neg_KL, XtRbar_j, logbf_mix, mu1_mix, S1_mix, xtx_j, Vinv);

// Compute V parameters
if (update_V)
compute_var_part_ERSS(var_part_ERSS, S1_mix, xtx_j);

// Update the expected residuals.
XtRbar -= (Xtx * trans(mu1_mix));
}
}

0 comments on commit 9e2ec99

Please sign in to comment.