Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

update compute_cov_flash() to work with latest flashier version #2

Merged
merged 1 commit into from
Aug 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
Encoding: UTF-8
Type: Package
Package: mr.mash.alpha
Version: 0.2-27
Date: 2023-05-19
Version: 0.2-28
Date: 2023-08-17
Title: Multiple Regression with Multivariate Adaptive Shrinkage
Description: Provides an implementation of methods for multivariate
multiple regression with adaptive shrinkage priors.
Expand All @@ -22,7 +22,7 @@ Imports:
matrixStats,
mashr (>= 0.2.41),
ebnm,
flashier,
flashier (>= 0.2.50),
parallel
Suggests:
testthat,
Expand Down
112 changes: 56 additions & 56 deletions R/misc.R
Original file line number Diff line number Diff line change
Expand Up @@ -39,18 +39,18 @@ removefromcols <- function (A, b)
###Function to simulate from MN distribution
#
#' @importFrom MBSP matrix_normal
#'
#'
sim_mvr <- function (X, B, V) {

# Get the number of samples (n) and conditions (m).
n <- nrow(X)
r <- ncol(B)

# Simulate the responses, Y.
M <- X%*%B
U <- diag(n)
Y <- matrix_normal(M, U, V)

# Output the simulated responses.
return(Y)
}
Expand All @@ -67,29 +67,29 @@ simplify2array_custom <- function (x, higher = TRUE) {
if (common.len >= 1L){
n <- length(x)
r <- unlist(x, recursive = FALSE, use.names = FALSE)
if (higher && length(c.dim <- unique(lapply(x, dim))) ==
if (higher && length(c.dim <- unique(lapply(x, dim))) ==
1 && is.numeric(c.dim <- c.dim[[1L]]) && prod(d <- c(c.dim, n)) == length(r)) {
iN1 <- is.null(n1 <- dimnames(x[[1L]]))
n2 <- names(x)
dnam <- if (!(iN1 && is.null(n2)))
c(if (iN1) rep.int(list(n1), length(c.dim)) else n1,
dnam <- if (!(iN1 && is.null(n2)))
c(if (iN1) rep.int(list(n1), length(c.dim)) else n1,
list(n2))
array(r, dim = d, dimnames = dnam)
}
else if (prod(d <- c(common.len, n)) == length(r))
array(r, dim = d, dimnames = if (!(is.null(n1 <- names(x[[1L]])) &
is.null(n2 <- names(x))))
else if (prod(d <- c(common.len, n)) == length(r))
array(r, dim = d, dimnames = if (!(is.null(n1 <- names(x[[1L]])) &
is.null(n2 <- names(x))))
list(n1, n2))
else x
} else {
x
}
}

###Add small number e to diagonal elements to a matrix
###Add small number e to diagonal elements to a matrix
makePD <- function(S0, e){
S0_PD <- S0+(diag(nrow(S0))*e)

return(S0_PD)
}

Expand All @@ -98,58 +98,58 @@ precompute_quants <- function(X, V, S0, standardize, version){
if(standardize){
n <- nrow(X)
xtx <- n-1

###Quantities that don't depend on S0
R <- chol(V)
S <- V/xtx
S_chol <- R/sqrt(xtx)

###Quantities that depend on S0
SplusS0_chol <- list()
S1 <- list()
for(i in 1:length(S0)){
SplusS0_chol[[i]] <- chol(S+S0[[i]])
S1[[i]] <- S0[[i]]%*%backsolve(SplusS0_chol[[i]], forwardsolve(t(SplusS0_chol[[i]]), S))
}

if(version=="R"){
return(list(V_chol=R, S=S, S1=S1, S_chol=S_chol, SplusS0_chol=SplusS0_chol))
return(list(V_chol=R, S=S, S1=S1, S_chol=S_chol, SplusS0_chol=SplusS0_chol))
} else if(version=="Rcpp"){
xtx <- c(0, 0) ##Vector
d <- matrix(0, nrow=1, ncol=1)
QtimesR <- array(0, c(1, 1, 1))
return(list(V_chol=R, S=S, S1=simplify2array_custom(S1), S_chol=S_chol, SplusS0_chol=simplify2array_custom(SplusS0_chol),

return(list(V_chol=R, S=S, S1=simplify2array_custom(S1), S_chol=S_chol, SplusS0_chol=simplify2array_custom(SplusS0_chol),
xtx=xtx, d=d, QtimesV_chol=QtimesR))
}

} else {
###Quantities that don't depend on S0
R <- chol(V)
#Rtinv <- solve(t(R))
#Rinv <- solve(R)
Rtinv <- forwardsolve(t(R), diag(nrow(R)))
Rinv <- backsolve(R, diag(nrow(R)))

###Quantities that depend on S0
d <- list()
QtimesR <- list()
for(i in 1:length(S0)){
U0 <- Rtinv %*% S0[[i]] %*% Rinv
out <- eigen(U0)
d[[i]] <- out$values
QtimesR[[i]] <- crossprod(out$vectors, R)
QtimesR[[i]] <- crossprod(out$vectors, R)
}

if(version=="R"){
return(list(V_chol=R, d=d, QtimesV_chol=QtimesR))
} else if(version=="Rcpp"){
S <- matrix(0, nrow=1, ncol=1)
S1 <- array(0, c(1, 1, 1))
S_chol <- matrix(0, nrow=1, ncol=1)
SplusS0_chol <- array(0, c(1, 1, 1))
return(list(V_chol=R, d=simplify2array_custom(d), QtimesV_chol=simplify2array_custom(QtimesR),

return(list(V_chol=R, d=simplify2array_custom(d), QtimesV_chol=simplify2array_custom(QtimesR),
S=S, S1=S1, S_chol=S_chol, SplusS0_chol=SplusS0_chol))
}
}
Expand All @@ -174,95 +174,95 @@ filter_precomputed_quants <- function(precomp_quants, to_keep, standardize, vers
precomp_quants$QtimesV_chol <- precomp_quants$QtimesV_chol[, , to_keep]
}
}

return(precomp_quants)
}

###Compute variance part of the ERSS
compute_var_part_ERSS <- function(var_part_ERSS, bfit, xtx){
var_part_ERSS <- var_part_ERSS + (bfit$S1*xtx)

return(var_part_ERSS)
}

###Rescale posterior mean and covariance of the regression coefficients when standardizing X
rescale_post_mean_covar <- function(mu1, S1, sx){
p <- nrow(mu1)
r <- ncol(mu1)

mu1_orig <- mu1/sx

S1_orig <- array(0, c(r, r, p))
for(j in 1:p){
S1_orig[, , j] <- S1[, , j]/sx[j]^2
}

return(list(mu1_orig=mu1_orig, S1_orig=S1_orig))
}

###Faster version of rescale_post_mean_covar()
rescale_post_mean_covar_fast <- function(mu1, S1, sx){
rescale_post_mean_covar_rcpp(mu1, S1, sx)
}


###Scale a matrix (similar to but faster than base::scale())
#' @importFrom matrixStats colSds colMeans2
#'
#'
scale_fast <- function(M, scale=TRUE, na.rm=TRUE){
##Check whether M is a matrix. If not, coerce into it.
##Check whether M is a matrix. If not, coerce into it.
if(!is.matrix(M))
M <- as.matrix(M)

##Store dimnames
col_names <- colnames(M)
row_names <- rownames(M)

###Compute column means and sds
a <- colMeans2(M, na.rm=na.rm)
names(a) <- col_names
if(scale){
if(scale){
b <- colSds(M, na.rm=na.rm)
if(any(b==0))
stop("Some column(s) have 0 standard deviation")
} else{
b <- rep(1, ncol(M))
}
names(b) <- col_names

###Scale
M <- scale_rcpp(M, a, b)

###Attach dimension names
colnames(M) <- col_names
rownames(M) <- row_names

return(list(M=M, means=a, sds=b))
}

###Scale a matrix (similar to the above but does not use R to compute means and sds)
scale_fast2 <- function(M, scale=TRUE, na.rm=TRUE){
##Check whether M is a matrix. If not, coerce into it.
##Check whether M is a matrix. If not, coerce into it.
if(!is.matrix(M))
M <- as.matrix(M)

##Store dimnames
col_names <- colnames(M)
row_names <- rownames(M)

###Scale
out <- scale2_rcpp(M, scale=scale, na_rm=na.rm)
means <- drop(out$means)
sds <- drop(out$sds)
M <- out$M
rm(out)

###Attach dimension names
colnames(M) <- col_names
rownames(M) <- row_names
names(means) <- col_names
names(sds) <- col_names

return(list(M=M, means=means, sds=sds))
}

Expand All @@ -273,18 +273,18 @@ scale_fast2 <- function(M, scale=TRUE, na.rm=TRUE){
#' @importFrom ebnm ebnm_normal
#' @importFrom ebnm ebnm_normal_scale_mixture
#' @importFrom flashier flash
#'
#'
compute_cov_flash <- function(Y, error_cache = NULL){
covar <- diag(ncol(Y))
tryCatch({
fl <- flash(Y, var.type = 2,
ebnm.fn = c(ebnm_normal,ebnm_normal_scale_mixture),
fl <- flash(Y, var_type = 2,
ebnm_fn = c(ebnm_normal,ebnm_normal_scale_mixture),
backfit = TRUE,verbose = 0)
if (fl$n.factors == 0) {
covar <- diag(fl$residuals.sd^2)
if (fl$n_factors == 0) {
covar <- diag(fl$residuals_sd^2)
} else {
fsd <- sapply(fl$L.ghat,"[[","sd")
covar <- diag(fl$residuals.sd^2) + crossprod(t(fl$F.pm) * fsd)
fsd <- sapply(fl$L_ghat,"[[","sd")
covar <- diag(fl$residuals_sd^2) + crossprod(t(fl$F_pm) * fsd)
}
if (nrow(covar) == 0) {
covar <- diag(ncol(Y))
Expand All @@ -311,31 +311,31 @@ compute_cov_flash <- function(Y, error_cache = NULL){
#' @importFrom stats cov
compute_V_init <- function(X, Y, B, intercept, method=c("cov", "flash")){
method <- match.arg(method)

R <- removefromcols(Y, intercept) - X%*%B

if(method=="cov")
V <- cov(R)
else if(method=="flash")
V <- compute_cov_flash(R)

return(V)
}

###Extract Y missingness patterns for each individual
extract_missing_Y_pattern <- function(Y){

n <- nrow(Y)
miss <- vector("list", n)
non_miss <- vector("list", n)

for(i in 1:n){
miss_i <- is.na(Y[i, ])
non_miss_i <- !miss_i
miss[[i]] <- miss_i
non_miss[[i]] <- non_miss_i
}

return(list(miss=miss, non_miss=non_miss))
}

Loading