Skip to content

Commit

Permalink
Merge pull request #2 from willwerscheid/master
Browse files Browse the repository at this point in the history
Update compute_cov_flash() to work with latest flashier version.
  • Loading branch information
pcarbo authored Aug 17, 2023
2 parents e6f3cd8 + 5c9b754 commit 468579b
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 59 deletions.
6 changes: 3 additions & 3 deletions DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
Encoding: UTF-8
Type: Package
Package: mr.mash.alpha
Version: 0.2-27
Date: 2023-05-19
Version: 0.2-28
Date: 2023-08-17
Title: Multiple Regression with Multivariate Adaptive Shrinkage
Description: Provides an implementation of methods for multivariate
multiple regression with adaptive shrinkage priors.
Expand All @@ -22,7 +22,7 @@ Imports:
matrixStats,
mashr (>= 0.2.41),
ebnm,
flashier,
flashier (>= 0.2.50),
parallel
Suggests:
testthat,
Expand Down
112 changes: 56 additions & 56 deletions R/misc.R
Original file line number Diff line number Diff line change
Expand Up @@ -39,18 +39,18 @@ removefromcols <- function (A, b)
# Simulate a response matrix from a matrix normal (MN) distribution.
#
# X: predictor matrix with one row per sample.
# B: coefficient matrix, conformable with X so that X %*% B is defined.
# V: covariance matrix passed as the third argument of matrix_normal()
#    (presumably the covariance among the columns of Y; confirm against
#    the MBSP documentation).
#
# Returns whatever matrix_normal() returns for mean X %*% B, an n x n
# identity matrix as its second argument, and V as its third.
#
#' @importFrom MBSP matrix_normal
#'
sim_mvr <- function (X, B, V) {

  # Number of samples; it sets the size of the identity matrix below.
  n <- nrow(X)

  # Simulate the responses, Y, around the mean X %*% B.
  M <- X %*% B
  U <- diag(n)
  Y <- matrix_normal(M, U, V)

  # Output the simulated responses.
  return(Y)
}
Expand All @@ -67,29 +67,29 @@ simplify2array_custom <- function (x, higher = TRUE) {
if (common.len >= 1L){
n <- length(x)
r <- unlist(x, recursive = FALSE, use.names = FALSE)
if (higher && length(c.dim <- unique(lapply(x, dim))) ==
if (higher && length(c.dim <- unique(lapply(x, dim))) ==
1 && is.numeric(c.dim <- c.dim[[1L]]) && prod(d <- c(c.dim, n)) == length(r)) {
iN1 <- is.null(n1 <- dimnames(x[[1L]]))
n2 <- names(x)
dnam <- if (!(iN1 && is.null(n2)))
c(if (iN1) rep.int(list(n1), length(c.dim)) else n1,
dnam <- if (!(iN1 && is.null(n2)))
c(if (iN1) rep.int(list(n1), length(c.dim)) else n1,
list(n2))
array(r, dim = d, dimnames = dnam)
}
else if (prod(d <- c(common.len, n)) == length(r))
array(r, dim = d, dimnames = if (!(is.null(n1 <- names(x[[1L]])) &
is.null(n2 <- names(x))))
else if (prod(d <- c(common.len, n)) == length(r))
array(r, dim = d, dimnames = if (!(is.null(n1 <- names(x[[1L]])) &
is.null(n2 <- names(x))))
list(n1, n2))
else x
} else {
x
}
}

# Add the constant e to every diagonal element of the matrix S0.
#
# S0: matrix whose row count sets the size of the identity matrix used.
# e:  value added to each diagonal entry.
#
# Returns S0 + e * I, where I is the nrow(S0) x nrow(S0) identity.
makePD <- function(S0, e){
  ridge <- diag(nrow(S0)) * e
  S0 + ridge
}

Expand All @@ -98,58 +98,58 @@ precompute_quants <- function(X, V, S0, standardize, version){
if(standardize){
n <- nrow(X)
xtx <- n-1

###Quantities that don't depend on S0
R <- chol(V)
S <- V/xtx
S_chol <- R/sqrt(xtx)

###Quantities that depend on S0
SplusS0_chol <- list()
S1 <- list()
for(i in 1:length(S0)){
SplusS0_chol[[i]] <- chol(S+S0[[i]])
S1[[i]] <- S0[[i]]%*%backsolve(SplusS0_chol[[i]], forwardsolve(t(SplusS0_chol[[i]]), S))
}

if(version=="R"){
return(list(V_chol=R, S=S, S1=S1, S_chol=S_chol, SplusS0_chol=SplusS0_chol))
return(list(V_chol=R, S=S, S1=S1, S_chol=S_chol, SplusS0_chol=SplusS0_chol))
} else if(version=="Rcpp"){
xtx <- c(0, 0) ##Vector
d <- matrix(0, nrow=1, ncol=1)
QtimesR <- array(0, c(1, 1, 1))
return(list(V_chol=R, S=S, S1=simplify2array_custom(S1), S_chol=S_chol, SplusS0_chol=simplify2array_custom(SplusS0_chol),

return(list(V_chol=R, S=S, S1=simplify2array_custom(S1), S_chol=S_chol, SplusS0_chol=simplify2array_custom(SplusS0_chol),
xtx=xtx, d=d, QtimesV_chol=QtimesR))
}

} else {
###Quantities that don't depend on S0
R <- chol(V)
#Rtinv <- solve(t(R))
#Rinv <- solve(R)
Rtinv <- forwardsolve(t(R), diag(nrow(R)))
Rinv <- backsolve(R, diag(nrow(R)))

###Quantities that depend on S0
d <- list()
QtimesR <- list()
for(i in 1:length(S0)){
U0 <- Rtinv %*% S0[[i]] %*% Rinv
out <- eigen(U0)
d[[i]] <- out$values
QtimesR[[i]] <- crossprod(out$vectors, R)
QtimesR[[i]] <- crossprod(out$vectors, R)
}

if(version=="R"){
return(list(V_chol=R, d=d, QtimesV_chol=QtimesR))
} else if(version=="Rcpp"){
S <- matrix(0, nrow=1, ncol=1)
S1 <- array(0, c(1, 1, 1))
S_chol <- matrix(0, nrow=1, ncol=1)
SplusS0_chol <- array(0, c(1, 1, 1))
return(list(V_chol=R, d=simplify2array_custom(d), QtimesV_chol=simplify2array_custom(QtimesR),

return(list(V_chol=R, d=simplify2array_custom(d), QtimesV_chol=simplify2array_custom(QtimesR),
S=S, S1=S1, S_chol=S_chol, SplusS0_chol=SplusS0_chol))
}
}
Expand All @@ -174,95 +174,95 @@ filter_precomputed_quants <- function(precomp_quants, to_keep, standardize, vers
precomp_quants$QtimesV_chol <- precomp_quants$QtimesV_chol[, , to_keep]
}
}

return(precomp_quants)
}

# Accumulate the variance part of the ERSS.
#
# var_part_ERSS: running total to be updated.
# bfit:          list-like object whose S1 element is scaled and added.
# xtx:           multiplier applied to bfit$S1.
#
# Returns var_part_ERSS + bfit$S1 * xtx.
compute_var_part_ERSS <- function(var_part_ERSS, bfit, xtx){
  increment <- bfit$S1 * xtx
  var_part_ERSS + increment
}

# Rescale the posterior mean and covariance of the regression coefficients
# back to the original scale of X when X was standardized.
#
# mu1: p x r matrix of posterior means (one row per predictor).
# S1:  r x r x p array; slice j is the posterior covariance for predictor j.
# sx:  length-p vector of scaling factors, one per predictor.
#
# Returns a list with
#   mu1_orig: mu1 with row j divided by sx[j];
#   S1_orig:  r x r x p array with slice j divided by sx[j]^2.
rescale_post_mean_covar <- function(mu1, S1, sx){
  p <- nrow(mu1)
  r <- ncol(mu1)

  # sx is recycled down the columns, so row j of mu1 is divided by sx[j].
  mu1_orig <- mu1/sx

  S1_orig <- array(0, c(r, r, p))
  # seq_len() rather than 1:p, so that p == 0 yields no iterations instead
  # of iterating over c(1, 0) and failing on an out-of-bounds subscript.
  for(j in seq_len(p)){
    S1_orig[, , j] <- S1[, , j]/sx[j]^2
  }

  return(list(mu1_orig=mu1_orig, S1_orig=S1_orig))
}

# Faster variant of rescale_post_mean_covar(): forwards mu1, S1 and sx
# unchanged to rescale_post_mean_covar_rcpp() and returns its result.
# NOTE(review): presumably the result has the same structure as
# rescale_post_mean_covar() returns; confirm against the Rcpp source.
rescale_post_mean_covar_fast <- function(mu1, S1, sx){
  return(rescale_post_mean_covar_rcpp(mu1, S1, sx))
}


# Scale a matrix (similar to but faster than base::scale()).
#
# M:     matrix, or an object that as.matrix() can coerce to one.
# scale: if TRUE, column standard deviations are computed and used for
#        scaling; if FALSE, a vector of ones is used instead.
# na.rm: passed to colMeans2() and colSds() to control NA removal.
#
# Returns a list with
#   M:     the output of scale_rcpp(M, means, sds) with the original
#          dimnames restored (presumably the centered/scaled matrix;
#          confirm against the Rcpp source);
#   means: column means, named by the column names of M;
#   sds:   column standard deviations (or ones), named likewise.
#
# Stops with an error if scale=TRUE and any column has zero standard
# deviation.
#
#' @importFrom matrixStats colSds colMeans2
#'
scale_fast <- function(M, scale=TRUE, na.rm=TRUE){
  ##Check whether M is a matrix. If not, coerce into it.
  if(!is.matrix(M))
    M <- as.matrix(M)

  ##Store dimnames
  col_names <- colnames(M)
  row_names <- rownames(M)

  ###Compute column means and sds
  a <- colMeans2(M, na.rm=na.rm)
  names(a) <- col_names
  if(scale){
    b <- colSds(M, na.rm=na.rm)
    # A zero sd would make the scaling divide by zero, so fail early.
    if(any(b==0))
      stop("Some column(s) have 0 standard deviation")
  } else{
    b <- rep(1, ncol(M))
  }
  names(b) <- col_names

  ###Scale
  M <- scale_rcpp(M, a, b)

  ###Attach dimension names
  colnames(M) <- col_names
  rownames(M) <- row_names

  return(list(M=M, means=a, sds=b))
}

# Scale a matrix like scale_fast(), except that the means and standard
# deviations are produced by scale2_rcpp() rather than computed in R.
#
# M:     matrix, or an object that as.matrix() can coerce to one.
# scale: forwarded to scale2_rcpp() as `scale`.
# na.rm: forwarded to scale2_rcpp() as `na_rm`.
#
# Returns a list with elements M, means and sds taken from the output of
# scale2_rcpp(), with the dimnames of the input re-attached to M and the
# column names attached to means and sds.
scale_fast2 <- function(M, scale=TRUE, na.rm=TRUE){
  # Coerce non-matrix input to a matrix before anything else.
  if(!is.matrix(M)){
    M <- as.matrix(M)
  }

  # Remember the dimnames so they can be restored on the result.
  col_names <- colnames(M)
  row_names <- rownames(M)

  # The scaled matrix, means and sds all come from a single call.
  scaled <- scale2_rcpp(M, scale=scale, na_rm=na.rm)
  col_means <- drop(scaled$means)
  col_sds <- drop(scaled$sds)
  M <- scaled$M
  # Release the intermediate list once its pieces have been extracted.
  rm(scaled)

  # Restore names on the matrix and on the per-column summaries.
  colnames(M) <- col_names
  rownames(M) <- row_names
  names(col_means) <- col_names
  names(col_sds) <- col_names

  list(M=M, means=col_means, sds=col_sds)
}

Expand All @@ -273,18 +273,18 @@ scale_fast2 <- function(M, scale=TRUE, na.rm=TRUE){
#' @importFrom ebnm ebnm_normal
#' @importFrom ebnm ebnm_normal_scale_mixture
#' @importFrom flashier flash
#'
#'
compute_cov_flash <- function(Y, error_cache = NULL){
covar <- diag(ncol(Y))
tryCatch({
fl <- flash(Y, var.type = 2,
ebnm.fn = c(ebnm_normal,ebnm_normal_scale_mixture),
fl <- flash(Y, var_type = 2,
ebnm_fn = c(ebnm_normal,ebnm_normal_scale_mixture),
backfit = TRUE,verbose = 0)
if (fl$n.factors == 0) {
covar <- diag(fl$residuals.sd^2)
if (fl$n_factors == 0) {
covar <- diag(fl$residuals_sd^2)
} else {
fsd <- sapply(fl$L.ghat,"[[","sd")
covar <- diag(fl$residuals.sd^2) + crossprod(t(fl$F.pm) * fsd)
fsd <- sapply(fl$L_ghat,"[[","sd")
covar <- diag(fl$residuals_sd^2) + crossprod(t(fl$F_pm) * fsd)
}
if (nrow(covar) == 0) {
covar <- diag(ncol(Y))
Expand All @@ -311,31 +311,31 @@ compute_cov_flash <- function(Y, error_cache = NULL){
#' @importFrom stats cov
compute_V_init <- function(X, Y, B, intercept, method=c("cov", "flash")){
  # Compute an initial estimate V from the residuals
  # removefromcols(Y, intercept) - X %*% B.
  #
  # method: "cov" applies stats::cov() to the residual matrix;
  #         "flash" passes the residual matrix to compute_cov_flash().
  #
  # Returns the estimate produced by the selected method.
  method <- match.arg(method)

  resid_mat <- removefromcols(Y, intercept) - X%*%B

  # match.arg() guarantees that method is exactly one of the two choices.
  V <- switch(method,
              cov = cov(resid_mat),
              flash = compute_cov_flash(resid_mat))

  V
}

# Extract Y missingness patterns for each individual (row of Y).
#
# Y: matrix with one row per individual; NA marks a missing entry.
#
# Returns a list with two elements, each a list of length nrow(Y):
#   miss:     element i is is.na(Y[i, ]);
#   non_miss: element i is the logical negation of miss[[i]].
# A Y with zero rows yields two empty lists.
extract_missing_Y_pattern <- function(Y){

  n <- nrow(Y)
  miss <- vector("list", n)
  non_miss <- vector("list", n)

  # seq_len() rather than 1:n, so that a zero-row Y produces no iterations
  # instead of indexing the non-existent rows 1 and 0.
  for(i in seq_len(n)){
    miss_i <- is.na(Y[i, ])
    miss[[i]] <- miss_i
    non_miss[[i]] <- !miss_i
  }

  return(list(miss=miss, non_miss=non_miss))
}

0 comments on commit 468579b

Please sign in to comment.