-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
First version of the cobiclust package
- Loading branch information
1 parent
6595a8d
commit 916582c
Showing
24 changed files
with
1,137 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
^.*\.Rproj$ | ||
^\.Rproj\.user$ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
.Rproj.user | ||
.Rhistory | ||
.RData | ||
.Ruserdata |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
Package: cobiclust | ||
Type: Package | ||
Title: Biclustering via Latent Block Model Adapted to Overdispersed Count Data | ||
Version: 0.1.0 | ||
Author: Julie Aubert | ||
Maintainer: Julie Aubert <[email protected]> | ||
Description: Implementation of a probabilistic method for biclustering | ||
adapted to overdispersed count data. It is a Gamma-Poisson Latent Block Model. | ||
It also implements two selection criteria in order to select the number of | ||
biclusters. | ||
Imports: | ||
cluster | ||
License: GPL-2 | ||
Encoding: UTF-8 | ||
LazyData: true | ||
RoxygenNote: 6.0.1 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
# Generated by roxygen2: do not edit by hand | ||
|
||
S3method(summary,cobiclustering) | ||
export(alpha_calculation) | ||
export(cobiclust) | ||
export(cobiclustering) | ||
export(dicho) | ||
export(foo_a) | ||
export(init_pam) | ||
export(is.cobiclustering) | ||
export(lb_calculation) | ||
export(penalty) | ||
export(qu_calculation) | ||
export(qukg_calculation) | ||
export(selection_criteria) | ||
import(cluster) | ||
importFrom(cluster,pam) | ||
importFrom(stats,dnbinom) | ||
importFrom(stats,median) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,252 @@ | ||
# cobiclust R package | ||
# Copyright INRA 2017 | ||
# UMR MIA-Paris, AgroParisTech, INRA, Universite Paris-Saclay, 75005, Paris, France | ||
#################################################################################### | ||
#' Perform a biclustering adapted to overdispersed count data.
#'
#' Starting from the initialisation returned by \code{init_pam}, every global
#' iteration alternates (1) an inner EM loop on the row memberships, (2) an
#' inner EM loop on the column memberships, (3) an update of the row effects
#' \code{mu_i} and, when \code{a} is \code{NULL}, of the dispersion, and
#' (4) the evaluation of the lower bound. The global loop stops when the
#' relative change of the lower bound falls below \code{cvg_lim} or after
#' \code{nbiter} iterations.
#'
#' @param x the input matrix of observed data.
#' @param K an integer specifying the number of groups in rows.
#' @param G an integer specifying the number of groups in columns.
#' @param nu_j column-effect terms entering the mean multiplicatively: either a
#'   vector whose length is equal to the number of columns of \code{x}, or a
#'   matrix with the same dimensions as \code{x}. The value is passed to
#'   \code{init_pam} and the \code{nu_j} it returns is the one actually used.
#' @param a a numeric dispersion value passed to \code{init_pam}. When
#'   \code{NULL} (the default) the dispersion is additionally re-estimated with
#'   \code{dicho} at every global iteration.
#' @param akg a logical variable indicating whether to use a common dispersion parameter (akg = FALSE) or a dispersion parameter per cocluster (akg = TRUE).
#' @param cvg_lim a number specifying the threshold used for convergence criterion (cvg_lim = 1e-05 by default).
#' @param nbiter the maximal number of iterations for the global loop of variational EM algorithm (nbiter = 5000 by default). Must be at least 1.
#' @return An object of class cobiclustering: a list with components
#'   \code{data}, \code{K}, \code{G}, \code{classification} (\code{rowclass},
#'   \code{colclass}), \code{strategy}, \code{parameters} and \code{info}.
#' @examples
#' npc <- c(50, 40) # nodes per class
#' KG <- c(2, 3) # classes
#' nm <- npc * KG # nodes
#' Z <- diag( KG[1]) %x% matrix(1, npc[1], 1)
#' W <- diag(KG[2]) %x% matrix(1, npc[2], 1)
#' L <- 70 * matrix( runif( KG[1] * KG[2]), KG[1], KG[2])
#' M_in_expectation <- Z %*% L %*% t(W)
#' size <- 50
#' M<-matrix(
#'   rnbinom(
#'     n = length(as.vector(M_in_expectation)),
#'     mu = as.vector(M_in_expectation), size = size)
#'   , nm[1], nm[2])
#' rownames(M) <- paste("OTU", 1:nrow(M), sep = "_")
#' colnames(M) <- paste("S", 1:ncol(M), sep = "_")
#' res <- cobiclust(M, K = 2, G = 3, nu_j = rep(1,120), a = 1/size, cvg_lim = 1e-5)
#' @seealso \code{\link{cobiclustering}} for the cobiclustering class.
#' @export
cobiclust <- function(x, K = 2, G = 3, nu_j = NULL, a = NULL, akg = FALSE,
                      cvg_lim = 1e-05, nbiter = 5000) {
  # The output is built from quantities computed inside the global loop, so at
  # least one iteration is required.
  if (nbiter < 1) {
    stop("'nbiter' must be at least 1.", call. = FALSE)
  }

  # Membership probabilities are kept inside [tol, 1 - tol].
  tol <- 1e-04

  # Parameter initialisation ---------------------------
  res_init <- init_pam(x = x, nu_j = nu_j, a = a, K = K, G = G, akg = akg)

  n <- nrow(x)
  m <- ncol(x)
  nu_j <- res_init$parameters$nu_j
  mu_i <- res_init$parameters$mu_i
  t_jg <- res_init$info$t_jg
  s_ik <- res_init$info$s_ik
  pi_c <- res_init$parameters$pi
  rho_c <- res_init$parameters$rho
  alpha_c <- matrix(res_init$parameters$alpha, nrow = K, ncol = G)
  a0 <- res_init$parameters$a
  exp_utilde <- res_init$info$exp_utilde
  exp_logutilde <- res_init$info$exp_logutilde
  lb <- NULL
  lbtt <- NULL

  # nu_j is fixed from here on: decide once how to slice it.
  nu_is_matrix <- is.matrix(nu_j)
  nu_col <- function(col) if (nu_is_matrix) nu_j[, col] else nu_j[col]
  nu_row <- function(row) if (nu_is_matrix) nu_j[row, ] else nu_j

  # Common dispersion (akg = FALSE) or one dispersion per cocluster.
  qu_update <- if (!akg) qu_calculation else qukg_calculation

  # Turn a matrix of unnormalised log-weights (one row per item, one column
  # per group) into membership probabilities bounded away from 0 and 1.
  normalise_memberships <- function(log_w) {
    centred <- log_w - rowMeans(log_w)
    prob <- exp(centred) / rowSums(exp(centred))
    if (any(is.nan(prob))) {
      # exp() overflowed: recentre on the row maximum instead of the row mean.
      shifted <- log_w - apply(log_w, 1, max)
      prob <- exp(shifted) / rowSums(exp(shifted))
    }
    prob <- pmin(pmax(prob, tol), 1 - tol)
    # Entries that hit the lower bound are set to tol / (number of groups - 1).
    prob[prob == tol] <- tol / (ncol(prob) - 1)
    prob / rowSums(prob)
  }

  # Global EM -----------------------------------------------------------------
  j <- 0
  crit <- 1

  while (crit > cvg_lim && j < nbiter) {
    j <- j + 1

    # EM on the rows -----------------------------------------------------------
    crit_em1 <- 1
    while (crit_em1 > cvg_lim) {
      pi_old <- pi_c
      alpha_old1 <- alpha_c

      # s_ik ------------------
      log_s <- sapply(seq_len(K), FUN = function(k) {
        log(pi_c[k]) + rowSums(sapply(seq_len(m), FUN = function(col) {
          rowSums(sapply(seq_len(G), FUN = function(l) {
            t_jg[col, l] * x[, col] *
              (log(alpha_c[k, l]) + exp_logutilde[, col]) -
              t_jg[col, l] * mu_i * nu_col(col) * alpha_c[k, l] *
                exp_utilde[, col]
          }))
        }))
      })
      # The freshly computed posterior must replace s_ik (same scheme as the
      # column step below).
      s_ik <- normalise_memberships(log_s)

      # Update pi_c, alpha ------------------
      pi_c <- colMeans(s_ik)
      alpha_c <- alpha_calculation(s_ik = s_ik, t_jg = t_jg, nu_j = nu_j,
                                   mu_i = mu_i, K = K, G = G, x = x,
                                   exp_utilde = exp_utilde)

      # Calculations of exp_utilde and exp_logutilde ------------------
      qu_param <- qu_update(s_ik = s_ik, t_jg = t_jg, x = x, mu_i = mu_i,
                            nu_j = nu_j, alpha_c = alpha_c, a = a0)
      exp_utilde <- qu_param$exp_utilde
      exp_logutilde <- qu_param$exp_logutilde

      # Stopping criterion. The `+` has to end the line: a line starting with
      # `+` is parsed as a separate statement and its value is discarded.
      crit_em1 <- sum((sort(alpha_c) - sort(alpha_old1))^2) +
        sum((sort(pi_c) - sort(pi_old))^2)
    }

    # EM on the columns --------------------------------------------------------
    crit_em2 <- 1
    while (crit_em2 > cvg_lim) {
      rho_old <- rho_c
      alpha_old2 <- alpha_c

      # t_jg ------------------
      log_t <- sapply(seq_len(G), FUN = function(l) {
        log(rho_c[l]) + rowSums(sapply(seq_len(n), FUN = function(row) {
          rowSums(sapply(seq_len(K), FUN = function(k) {
            s_ik[row, k] * x[row, ] *
              (log(alpha_c[k, l]) + exp_logutilde[row, ]) -
              s_ik[row, k] * mu_i[row] * nu_row(row) * alpha_c[k, l] *
                exp_utilde[row, ]
          }))
        }))
      })
      t_jg <- normalise_memberships(log_t)

      # Update rho_c, alpha ------------------
      rho_c <- colMeans(t_jg)
      alpha_c <- alpha_calculation(s_ik = s_ik, t_jg = t_jg, nu_j = nu_j,
                                   mu_i = mu_i, K = K, G = G, x = x,
                                   exp_utilde = exp_utilde)

      # Calculations of exp_utilde and exp_logutilde ------------------
      qu_param <- qu_update(s_ik = s_ik, t_jg = t_jg, x = x, mu_i = mu_i,
                            nu_j = nu_j, alpha_c = alpha_c, a = a0)
      exp_utilde <- qu_param$exp_utilde
      exp_logutilde <- qu_param$exp_logutilde

      crit_em2 <- sum((sort(rho_c) - sort(rho_old))^2) +
        sum((sort(alpha_c) - sort(alpha_old2))^2)
    }

    # Estimation of mu_i ------------------
    if (nu_is_matrix) {
      mu_denom <- rowSums(s_ik * (nu_j %*% tcrossprod(t_jg, alpha_c)))
    } else {
      # v[k] = sum_l alpha_c[k, l] * sum_j t_jg[j, l] * nu_j[j]; the
      # denominator for row i is sum_k s_ik[i, k] * v[k]. A matrix product is
      # required: `s_ik * v` would recycle v down the columns of s_ik instead
      # of weighting column k by v[k].
      group_effect <- alpha_c %*% colSums(t_jg * nu_j)
      mu_denom <- as.vector(s_ik %*% group_effect)
    }
    mu_i <- rowSums(x %*% t_jg) / mu_denom

    # Estimation of a ------------------
    lb_old <- lb
    if (is.null(a)) {
      if (!akg) {
        left_bound <- sum(exp_logutilde)
        right_bound <- sum(exp_utilde)
        a0 <- dicho(x = 0.01, y = abs(max(left_bound, right_bound)),
                    threshold = 1e-08, nb = n * m,
                    left_bound = left_bound, right_bound = right_bound)
      } else {
        left_bound <- crossprod(s_ik, exp_logutilde %*% t_jg)
        right_bound <- crossprod(s_ik, exp_utilde %*% t_jg)
        n_kg <- crossprod(s_ik, matrix(1, nrow = n, ncol = m) %*% t_jg)
        a0 <- matrix(
          sapply(seq_len(K * G), FUN = function(g) {
            dicho(x = 0.01, y = 100, threshold = 1e-08, nb = n_kg[g],
                  left_bound = left_bound[g], right_bound = right_bound[g])
          }),
          nrow = K, ncol = G
        )
      }
    }

    # Refresh the variational parameters and alpha with the new a and mu_i.
    # NOTE(review): exp_utilde / exp_logutilde used by the next global
    # iteration are not refreshed from this qu_param -- confirm this is
    # intended.
    qu_param <- qu_update(s_ik = s_ik, t_jg = t_jg, x = x, mu_i = mu_i,
                          nu_j = nu_j, alpha_c = alpha_c, a = a0)
    alpha_c <- alpha_calculation(s_ik = s_ik, t_jg = t_jg, nu_j = nu_j,
                                 mu_i = mu_i, K = K, G = G, x = x,
                                 exp_utilde = qu_param$exp_utilde)

    # Calculation of the lower bound ----------------------------------------
    lb_out <- lb_calculation(x = x, qu_param = qu_param, s_ik = s_ik,
                             pi_c = pi_c, t_jg = t_jg, rho_c = rho_c,
                             mu_i = mu_i, nu_j = nu_j, alpha_c = alpha_c,
                             a = a0, akg = akg)
    lb <- lb_out$lb

    # Stopping criterion based on the lower bound (needs two evaluations).
    if (j > 1) {
      crit <- abs((lb - lb_old) / lb_old)
    }
    lbtt <- c(lbtt, lb)
  }

  # MAP classification
  colclass <- apply(t_jg, 1, which.max)
  rowclass <- apply(s_ik, 1, which.max)

  # Output
  strategy <- list(akg = akg, cvg_lim = cvg_lim)
  parameters <- list(alpha = alpha_c, pi = pi_c, rho = rho_c, mu_i = mu_i,
                     nu_j = nu_j, a = a0)
  info <- list(s_ik = s_ik, t_jg = t_jg,
               exp_logutilde = qu_param$exp_logutilde,
               exp_utilde = qu_param$exp_utilde,
               lb = lb, ent_ZW = lb_out$ent_ZW, nbiter = j, lbtt = lbtt,
               a_tilde = qu_param$a_tilde, b_tilde = qu_param$b_tilde)
  output <- list(data = x, K = K, G = G,
                 classification = list(rowclass = rowclass,
                                       colclass = colclass),
                 strategy = strategy, parameters = parameters, info = info)
  class(output) <- append(class(output), "cobiclustering")
  output
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
# cobiclust R package | ||
# Copyright INRA 2017 | ||
# UMR MIA-Paris, AgroParisTech, INRA, Universite Paris-Saclay, 75005, Paris, France | ||
#################################################################################### | ||
#' Creation of the cobiclustering class.
#'
#' Returns a template object of class \code{cobiclustering}: a list whose
#' components name the expected type of each field. The arguments are accepted
#' for the sake of the signature but are not stored in the returned object.
#' @export cobiclustering
#' @keywords internal
cobiclustering <- function(data = matrix(nrow = 3, ncol = 3, NA), K = 2, G = 2, classification = list(length=2), strategy = list(), parameters = list(), info = list()) {
  # Field name -> expected type of that field.
  structure(
    list(
      data = "matrix",
      K = "numeric",
      G = "numeric",
      classification = "list",
      strategy = "list",
      parameters = "list",
      info = "list"
    ),
    class = "cobiclustering"
  )
}
|
||
#' Is an object of class cobiclustering ?
#'
#' Uses \code{inherits} so that the answer is always a single logical, also
#' for objects whose class vector has several elements (\code{cobiclust}
#' returns objects of class \code{c("list", "cobiclustering")}).
#'
#' @param object any R object.
#' @return \code{TRUE} if \code{object} inherits from class
#'   \code{cobiclustering}, \code{FALSE} otherwise.
#' @export is.cobiclustering
#' @keywords internal
is.cobiclustering <- function(object) {
  inherits(object, "cobiclustering")
}
|
||
#' Summary of an object of class cobiclustering
#'
#' Prints the number of biclusters, the row and column proportions, the matrix
#' of alpha_kg interaction terms and the frequencies of the MAP classification.
#'
#' @param object an object of class cobiclustering.
#' @param ... ignored
#' @keywords internal
#' @method summary cobiclustering
#' @export
summary.cobiclustering <- function (object, ...){
  # Horizontal rule printed between the sections of the report.
  print_rule <- function() {
    cat("-----------------------------------------------------------\n")
  }

  print_rule()
  cat("Number of biclusters: K = ", object$K, "* G = ", object$G,"\n")
  cat("Row proportions:", round(object$parameters$pi, 3) ,"\n")
  cat("Column proportions:", round(object$parameters$rho, 3) ,"\n")
  print_rule()

  cat("Matrix of alpha_kg interactions terms: \n")
  alpha_rounded <- round(object$parameters$alpha, 3)
  print(matrix(alpha_rounded, nrow = object$K, ncol = object$G))
  print_rule()

  cat("Frequencies of MAP classification \n")
  cat("Rows")
  print(table(object$classification$rowclass))
  cat("Columns")
  print(table(object$classification$colclass))
  cat(" \n")

  # The next summary method is invoked; its value is not used.
  NextMethod("summary",object)
  print_rule()
}
Oops, something went wrong.