Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support sparse LD matrix #5

Merged
merged 7 commits into from
Jun 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
Encoding: UTF-8
Type: Package
Package: mr.mash.alpha
Version: 0.3.25
Date: 2024-05-06
Version: 0.3.27
Date: 2024-05-29
Title: Multiple Regression with Multivariate Adaptive Shrinkage
Description: Provides an implementation of methods for multivariate
multiple regression with adaptive shrinkage priors.
Expand All @@ -18,15 +18,15 @@ License: MIT + file LICENSE
Depends: R (>= 3.1.0)
Imports:
stats,
Matrix,
Rcpp (>= 1.0.7),
RcppParallel (>= 5.1.5),
mvtnorm,
matrixStats,
mashr (>= 0.2.73),
ebnm,
flashier (>= 1.0.7),
parallel,
Rfast
parallel
Suggests:
testthat,
varbvs,
Expand Down
5 changes: 4 additions & 1 deletion NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,14 @@ export(mr.mash.rss)
export(predict.mr.mash)
export(predict.mr.mash.rss)
export(simulate_mr_mash_data)
importFrom(Matrix,crossprod)
importFrom(Matrix,diag)
importFrom(Matrix,isSymmetric)
importFrom(Matrix,t)
importFrom(Rcpp,evalCpp)
importFrom(RcppParallel,RcppParallelLibs)
importFrom(RcppParallel,defaultNumThreads)
importFrom(RcppParallel,setThreadOptions)
importFrom(Rfast,is.symmetric)
importFrom(ebnm,ebnm_normal)
importFrom(ebnm,ebnm_normal_scale_mixture)
importFrom(flashier,flash)
Expand Down
4 changes: 4 additions & 0 deletions R/RcppExports.R
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@ inner_loop_general_rss_rcpp <- function(n, XtX, XtY, XtRbar, mu1, V, Vinv, w0, S
.Call('_mr_mash_alpha_inner_loop_general_rss_rcpp', PACKAGE = 'mr.mash.alpha', n, XtX, XtY, XtRbar, mu1, V, Vinv, w0, S0, precomp_quants_list, standardize, compute_ELBO, update_V, update_order, eps, nthreads)
}

inner_loop_general_rss_sparse_rcpp <- function(n, XtX, XtY, XtRbar, mu1, V, Vinv, w0, S0, precomp_quants_list, standardize, compute_ELBO, update_V, update_order, eps, nthreads) {
.Call('_mr_mash_alpha_inner_loop_general_rss_sparse_rcpp', PACKAGE = 'mr.mash.alpha', n, XtX, XtY, XtRbar, mu1, V, Vinv, w0, S0, precomp_quants_list, standardize, compute_ELBO, update_V, update_order, eps, nthreads)
}

scale_rcpp <- function(M, a, b) {
.Call('_mr_mash_alpha_scale_rcpp', PACKAGE = 'mr.mash.alpha', M, a, b)
}
Expand Down
21 changes: 14 additions & 7 deletions R/mr_mash_rss.R
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
#' @param Z p x r matrix of Z-scores from univariate
#' simple linear regression.
#'
#' @param R p x p correlation matrix among the variables.
#' @param R p x p dense or sparse correlation matrix among the variables.
#'
#' @param covY r x r covariance matrix across responses.
#'
Expand Down Expand Up @@ -171,7 +171,7 @@
#' abline(a = 0,b = 1,col = "magenta",lty = "dotted")
#'
#' @importFrom stats cov
#' @importFrom Rfast is.symmetric
#' @importFrom Matrix isSymmetric t diag crossprod
#' @importFrom RcppParallel defaultNumThreads
#' @importFrom RcppParallel setThreadOptions
#'
Expand Down Expand Up @@ -242,14 +242,14 @@ mr.mash.rss <- function(Bhat, Shat, Z, R, covY, n, S0, w0=rep(1/(length(S0)), le
if(any(is.na(Z)))
stop("Z must not contain missing values.")
}
if(!is.matrix(V) || !is.symmetric(V))
if(!is.matrix(V) || !isSymmetric(V))
stop("V must be a symmetric matrix.")
if(!missing(covY)){
if(!is.matrix(covY) || !is.symmetric(covY))
if(!is.matrix(covY) || !isSymmetric(covY))
stop("covY must be a symmetric matrix.")
}
if(!is.matrix(R) || !is.symmetric(R))
stop("R must be a symmetric matrix.")
if(!(is.matrix(R) || inherits(R,"CsparseMatrix")) || !isSymmetric(R))
stop("R must be a dense or sparse symmetric matrix.")
if(!is.list(S0))
stop("S0 must be a list.")
if(!is.vector(w0))
Expand All @@ -271,6 +271,9 @@ mr.mash.rss <- function(Bhat, Shat, Z, R, covY, n, S0, w0=rep(1/(length(S0)), le
# PRE-PROCESSING STEPS
# --------------------

###Check if R is sparse
R_is_sparse <- inherits(R,"CsparseMatrix")

###Compute Z scores
if(missing(Z)){
Z <- Bhat/Shat
Expand Down Expand Up @@ -303,6 +306,10 @@ mr.mash.rss <- function(Bhat, Shat, Z, R, covY, n, S0, w0=rep(1/(length(S0)), le
covY <- cov2cor(V)
}

if(R_is_sparse){
XtX <- as(XtX, "symmetricMatrix")
}

YtY <- covY*(n-1)

###Check whether XtX is positive semidefinite
Expand Down Expand Up @@ -490,7 +497,7 @@ mr.mash.rss <- function(Bhat, Shat, Z, R, covY, n, S0, w0=rep(1/(length(S0)), le
standardize=standardize,
update_V=update_V, version=version,
update_order=update_order, eps=eps,
nthreads=nthreads)
R_is_sparse=R_is_sparse, nthreads=nthreads)
mu1_t <- ups$mu1_t
S1_t <- ups$S1_t
w1_t <- ups$w1_t
Expand Down
25 changes: 17 additions & 8 deletions R/mr_mash_rss_updates.R
Original file line number Diff line number Diff line change
Expand Up @@ -77,11 +77,18 @@ inner_loop_general_rss_R <- function(n, XtX, XtY, XtRbar, mu1, V, Vinv, w0, S0,
#'
inner_loop_general_rss_Rcpp <- function(n, XtX, XtY, XtRbar, mu1, V, Vinv, w0, S0, precomp_quants,
standardize, compute_ELBO, update_V, update_order,
eps, nthreads){
eps, R_is_sparse, nthreads){

out <- inner_loop_general_rss_rcpp(n, XtX, XtY, XtRbar, mu1, V, Vinv, w0, S0, precomp_quants,
standardize, compute_ELBO, update_V, update_order,
eps, nthreads)
if(R_is_sparse){
out <- inner_loop_general_rss_sparse_rcpp(n, XtX, XtY, XtRbar, mu1, V, Vinv, w0, S0, precomp_quants,
standardize, compute_ELBO, update_V, update_order,
eps, nthreads)
} else {
out <- inner_loop_general_rss_rcpp(n, XtX, XtY, XtRbar, mu1, V, Vinv, w0, S0, precomp_quants,
standardize, compute_ELBO, update_V, update_order,
eps, nthreads)
}


###Return output
if(compute_ELBO && update_V){
Expand All @@ -100,14 +107,14 @@ inner_loop_general_rss_Rcpp <- function(n, XtX, XtY, XtRbar, mu1, V, Vinv, w0, S
###Wrapper of the inner loop with R or Rcpp
inner_loop_general_rss <- function(n, XtX, XtY, XtRbar, mu1, V, Vinv, w0, S0, precomp_quants,
standardize, compute_ELBO, update_V, version,
update_order, eps, nthreads){
update_order, eps, R_is_sparse, nthreads){
if(version=="R"){
out <- inner_loop_general_rss_R(n, XtX, XtY, XtRbar, mu1, V, Vinv, w0, S0, precomp_quants,
standardize, compute_ELBO, update_V, update_order, eps)
} else if(version=="Rcpp"){
update_order <- as.integer(update_order-1)
out <- inner_loop_general_rss_Rcpp(n, XtX, XtY, XtRbar, mu1, V, Vinv, w0, simplify2array_custom(S0), precomp_quants,
standardize, compute_ELBO, update_V, update_order, eps, nthreads)
standardize, compute_ELBO, update_V, update_order, eps, R_is_sparse, nthreads)
}

return(out)
Expand All @@ -118,18 +125,19 @@ inner_loop_general_rss <- function(n, XtX, XtY, XtRbar, mu1, V, Vinv, w0, S0, pr
mr_mash_update_general_rss <- function(n, XtX, XtY, YtY, mu1_t, V, Vinv, ldetV, w0, S0,
precomp_quants, compute_ELBO, standardize,
update_V, version, update_order, eps,
nthreads){
R_is_sparse, nthreads){



##Compute expected residuals
XtRbar <- XtY - XtX %*% mu1_t
XtRbar <- as.matrix(XtRbar)

##Update variational parameters, expected residuals, and ELBO components
updates <- inner_loop_general_rss(n=n, XtX=XtX, XtY=XtY, XtRbar=XtRbar, mu1=mu1_t, V=V, Vinv=Vinv, w0=w0, S0=S0,
precomp_quants=precomp_quants, standardize=standardize,
compute_ELBO=compute_ELBO, update_V=update_V, version=version,
update_order=update_order, eps=eps, nthreads=nthreads)
update_order=update_order, eps=eps, R_is_sparse=R_is_sparse, nthreads=nthreads)
mu1_t <- updates$mu1
S1_t <- updates$S1
w1_t <- updates$w1
Expand All @@ -138,6 +146,7 @@ mr_mash_update_general_rss <- function(n, XtX, XtY, YtY, mu1_t, V, Vinv, ldetV,

if(compute_ELBO || update_V){
RbartRbar <- YtY - crossprod(mu1_t, XtY) - crossprod(XtY, mu1_t) + crossprod(mu1_t, XtX)%*%mu1_t
RbartRbar <- as.matrix(RbartRbar)
}

if(compute_ELBO && update_V){
Expand Down
2 changes: 1 addition & 1 deletion man/mr.mash.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions man/mr.mash.rss.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

27 changes: 27 additions & 0 deletions src/RcppExports.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,32 @@ BEGIN_RCPP
return rcpp_result_gen;
END_RCPP
}
// inner_loop_general_rss_sparse_rcpp
List inner_loop_general_rss_sparse_rcpp(unsigned int n, const arma::sp_mat& XtX, const arma::mat& XtY, arma::mat& XtRbar, arma::mat& mu1, const arma::mat& V, const arma::mat& Vinv, const arma::vec& w0, const arma::cube& S0, const List& precomp_quants_list, bool standardize, bool compute_ELBO, bool update_V, const arma::vec& update_order, double eps, unsigned int nthreads);
RcppExport SEXP _mr_mash_alpha_inner_loop_general_rss_sparse_rcpp(SEXP nSEXP, SEXP XtXSEXP, SEXP XtYSEXP, SEXP XtRbarSEXP, SEXP mu1SEXP, SEXP VSEXP, SEXP VinvSEXP, SEXP w0SEXP, SEXP S0SEXP, SEXP precomp_quants_listSEXP, SEXP standardizeSEXP, SEXP compute_ELBOSEXP, SEXP update_VSEXP, SEXP update_orderSEXP, SEXP epsSEXP, SEXP nthreadsSEXP) {
BEGIN_RCPP
Rcpp::RObject rcpp_result_gen;
Rcpp::RNGScope rcpp_rngScope_gen;
Rcpp::traits::input_parameter< unsigned int >::type n(nSEXP);
Rcpp::traits::input_parameter< const arma::sp_mat& >::type XtX(XtXSEXP);
Rcpp::traits::input_parameter< const arma::mat& >::type XtY(XtYSEXP);
Rcpp::traits::input_parameter< arma::mat& >::type XtRbar(XtRbarSEXP);
Rcpp::traits::input_parameter< arma::mat& >::type mu1(mu1SEXP);
Rcpp::traits::input_parameter< const arma::mat& >::type V(VSEXP);
Rcpp::traits::input_parameter< const arma::mat& >::type Vinv(VinvSEXP);
Rcpp::traits::input_parameter< const arma::vec& >::type w0(w0SEXP);
Rcpp::traits::input_parameter< const arma::cube& >::type S0(S0SEXP);
Rcpp::traits::input_parameter< const List& >::type precomp_quants_list(precomp_quants_listSEXP);
Rcpp::traits::input_parameter< bool >::type standardize(standardizeSEXP);
Rcpp::traits::input_parameter< bool >::type compute_ELBO(compute_ELBOSEXP);
Rcpp::traits::input_parameter< bool >::type update_V(update_VSEXP);
Rcpp::traits::input_parameter< const arma::vec& >::type update_order(update_orderSEXP);
Rcpp::traits::input_parameter< double >::type eps(epsSEXP);
Rcpp::traits::input_parameter< unsigned int >::type nthreads(nthreadsSEXP);
rcpp_result_gen = Rcpp::wrap(inner_loop_general_rss_sparse_rcpp(n, XtX, XtY, XtRbar, mu1, V, Vinv, w0, S0, precomp_quants_list, standardize, compute_ELBO, update_V, update_order, eps, nthreads));
return rcpp_result_gen;
END_RCPP
}
// scale_rcpp
arma::mat scale_rcpp(const arma::mat& M, const arma::vec& a, const arma::vec& b);
RcppExport SEXP _mr_mash_alpha_scale_rcpp(SEXP MSEXP, SEXP aSEXP, SEXP bSEXP) {
Expand Down Expand Up @@ -160,6 +186,7 @@ static const R_CallMethodDef CallEntries[] = {
{"_mr_mash_alpha_inner_loop_general_rcpp", (DL_FUNC) &_mr_mash_alpha_inner_loop_general_rcpp, 14},
{"_mr_mash_alpha_impute_missing_Y_rcpp", (DL_FUNC) &_mr_mash_alpha_impute_missing_Y_rcpp, 5},
{"_mr_mash_alpha_inner_loop_general_rss_rcpp", (DL_FUNC) &_mr_mash_alpha_inner_loop_general_rss_rcpp, 16},
{"_mr_mash_alpha_inner_loop_general_rss_sparse_rcpp", (DL_FUNC) &_mr_mash_alpha_inner_loop_general_rss_sparse_rcpp, 16},
{"_mr_mash_alpha_scale_rcpp", (DL_FUNC) &_mr_mash_alpha_scale_rcpp, 3},
{"_mr_mash_alpha_scale2_rcpp", (DL_FUNC) &_mr_mash_alpha_scale2_rcpp, 3},
{"_mr_mash_alpha_rescale_post_mean_covar_rcpp", (DL_FUNC) &_mr_mash_alpha_rescale_post_mean_covar_rcpp, 3},
Expand Down
123 changes: 123 additions & 0 deletions src/mr_mash_updates.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,15 @@ void inner_loop_general_rss (unsigned int n, const mat& XtX, const mat& XtY, mat
cube& S1, mat& w1, double& var_part_tr_wERSS,
double& neg_KL, mat& var_part_ERSS);

// Inner loop rss sprse
void inner_loop_general_rss_sparse (unsigned int n, const sp_mat& XtX, const mat& XtY, mat& XtRbar,
mat& mu1, const mat& V, const mat& Vinv, const vec& w0,
const cube& S0, const mr_mash_precomputed_quantities& precomp_quants,
bool standardize, bool compute_ELBO, bool update_V,
const vec& update_order, double eps, unsigned int nthreads,
cube& S1, mat& w1, double& var_part_tr_wERSS,
double& neg_KL, mat& var_part_ERSS);


// FUNCTION DEFINITIONS
// --------------------
Expand Down Expand Up @@ -357,3 +366,117 @@ void inner_loop_general_rss (unsigned int n, const mat& XtX, const mat& XtY, mat
XtRbar -= (Xtx * trans(mu1_mix));
}
}


// Inner loop rss sparse
//
// [[Rcpp::depends(RcppArmadillo)]]
// [[Rcpp::depends(RcppParallel)]]
// [[Rcpp::export]]
List inner_loop_general_rss_sparse_rcpp (unsigned int n, const arma::sp_mat& XtX, const arma::mat& XtY,
arma::mat& XtRbar, arma::mat& mu1, const arma::mat& V,
const arma::mat& Vinv, const arma::vec& w0,
const arma::cube& S0, const List& precomp_quants_list,
bool standardize, bool compute_ELBO, bool update_V,
const arma::vec& update_order, double eps, unsigned int nthreads) {
unsigned int r = mu1.n_cols;
unsigned int p = mu1.n_rows;
unsigned int k = w0.n_elem;
cube S1(r,r,p);
mat w1(p,k);
mat mu1_new = mu1;
double var_part_tr_wERSS;
double neg_KL;
mat var_part_ERSS(r,r);
mr_mash_precomputed_quantities precomp_quants
(as<mat>(precomp_quants_list["S"]),
as<mat>(precomp_quants_list["S_chol"]),
as<cube>(precomp_quants_list["S1"]),
as<cube>(precomp_quants_list["SplusS0_chol"]),
as<mat>(precomp_quants_list["V_chol"]),
as<mat>(precomp_quants_list["d"]),
as<cube>(precomp_quants_list["QtimesV_chol"]),
as<vec>(precomp_quants_list["xtx"]));
inner_loop_general_rss_sparse(n, XtX, XtY, XtRbar, mu1_new, V, Vinv, w0, S0, precomp_quants,
standardize, compute_ELBO, update_V, update_order, eps,
nthreads, S1, w1, var_part_tr_wERSS, neg_KL, var_part_ERSS);
return List::create(Named("mu1") = mu1_new,
Named("S1") = S1,
Named("w1") = w1,
Named("var_part_tr_wERSS") = var_part_tr_wERSS,
Named("neg_KL") = neg_KL,
Named("var_part_ERSS") = var_part_ERSS);
}

// Perform the inner loop rss
void inner_loop_general_rss_sparse (unsigned int n, const sp_mat& XtX, const mat& XtY, mat& XtRbar,
mat& mu1, const mat& V, const mat& Vinv, const vec& w0,
const cube& S0, const mr_mash_precomputed_quantities& precomp_quants,
bool standardize, bool compute_ELBO, bool update_V,
const vec& update_order, double eps, unsigned int nthreads,
cube& S1, mat& w1, double& var_part_tr_wERSS,
double& neg_KL, mat& var_part_ERSS) {
unsigned int p = mu1.n_rows;
unsigned int r = mu1.n_cols;
unsigned int k = w0.n_elem;
vec Xtx(p);
vec XtRbar_j(r);
vec mu1_j(r);
vec mu1_mix(r);
mat S1_mix(r,r);
vec w1_mix(k);
double logbf_mix;
double xtx_j;

// Initialize ELBO parameters
var_part_tr_wERSS = 0;
neg_KL = 0;

// Initialize V parameters
var_part_ERSS.zeros(r,r);

// Repeat for each predictor.
for (unsigned int j : update_order) {

if (standardize)
xtx_j = n-1;
else
xtx_j = precomp_quants.xtx(j);

Xtx = XtX.col(j);
mu1_j = trans(mu1.row(j));

// Disregard the jth predictor in the expected residuals.
XtRbar += (Xtx * trans(mu1_j));
XtRbar_j = trans(XtRbar.row(j));

// Update the posterior quantities for the jth
// predictor.
if (standardize)
logbf_mix = bayes_mvr_mix_standardized_X_rss(n, XtRbar_j, w0, S0, precomp_quants.S,
precomp_quants.S1,
precomp_quants.SplusS0_chol,
precomp_quants.S_chol, eps, nthreads,
mu1_mix, S1_mix, w1_mix);
else
logbf_mix = bayes_mvr_mix_centered_X_rss(XtRbar_j, V, w0, S0, xtx_j, Vinv,
precomp_quants.V_chol, precomp_quants.d,
precomp_quants.QtimesV_chol, eps, nthreads,
mu1_mix, S1_mix, w1_mix);

mu1.row(j) = trans(mu1_mix);
S1.slice(j) = S1_mix;
w1.row(j) = trans(w1_mix);

// Compute ELBO parameters
if (compute_ELBO)
compute_ELBO_rss_terms(var_part_tr_wERSS, neg_KL, XtRbar_j, logbf_mix, mu1_mix, S1_mix, xtx_j, Vinv);

// Compute V parameters
if (update_V)
compute_var_part_ERSS(var_part_ERSS, S1_mix, xtx_j);

// Update the expected residuals.
XtRbar -= (Xtx * trans(mu1_mix));
}
}
Loading