Skip to content
This repository has been archived by the owner on May 27, 2024. It is now read-only.

Commit

Permalink
Merge pull request #38 from scverse/improve-simple-tests
Browse files Browse the repository at this point in the history
T-test / Improve simple tests
  • Loading branch information
grst authored Apr 2, 2024
2 parents 05fe4b5 + 0546ebf commit 7610d50
Show file tree
Hide file tree
Showing 20 changed files with 598 additions and 267 deletions.
27 changes: 25 additions & 2 deletions src/multi_condition_comparisions/_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,29 @@
from scipy.sparse import issparse, spmatrix


def check_is_numeric_matrix(array: np.ndarray | spmatrix) -> None:
"""Check if a matrix is numeric and only contains finite/non-NA values
Parameters
----------
array
dense or sparse matrix to check
Raises
------
ValueError
if the matrix is not numeric or contains NaNs or infinite values
"""
if not np.issubdtype(array.dtype, np.number):
raise ValueError("Counts must be numeric.")
if issparse(array):
if np.any(~np.isfinite(array.data)):
raise ValueError("Counts cannot contain negative, NaN or Inf values.")
else:
if np.any(~np.isfinite(array)):
raise ValueError("Counts cannot contain negative, NaN or Inf values.")


def check_is_integer_matrix(array: np.ndarray | spmatrix, tolerance: float = 1e-6) -> None:
"""Check if a matrix container integers, or floats that are close to integers.
Expand All @@ -18,10 +41,10 @@ def check_is_integer_matrix(array: np.ndarray | spmatrix, tolerance: float = 1e-
if the matrix contains valuese that are not close to integers
"""
if issparse(array):
if not array.data.dtype.kind == "i" or not np.all(np.abs(array.data - np.round(array.data)) < tolerance):
if not array.data.dtype.kind == "i" and not np.all(np.abs(array.data - np.round(array.data)) < tolerance):
raise ValueError("Non-zero elements of the matrix must be close to integer values.")
else:
if not array.dtype.kind == "i" or not np.all(np.abs(array - np.round(array)) < tolerance):
if not array.dtype.kind == "i" and not np.all(np.abs(array - np.round(array)) < tolerance):
raise ValueError("Matrix must be a count matrix.")
if (array < 0).sum() > 0:
raise ValueError("Non.zero elements of the matrix must be postiive.")
16 changes: 14 additions & 2 deletions src/multi_condition_comparisions/methods/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,19 @@
from ._base import ContrastType, LinearModelBase, MethodBase
from ._edger import EdgeR
from ._pydeseq2 import PyDESeq2
from ._simple_tests import WilcoxonTest
from ._simple_tests import SimpleComparisonBase, TTest, WilcoxonTest
from ._statsmodels import Statsmodels

__all__ = ["MethodBase", "LinearModelBase", "EdgeR", "PyDESeq2", "Statsmodels", "WilcoxonTest", "ContrastType"]
__all__ = [
"MethodBase",
"LinearModelBase",
"EdgeR",
"PyDESeq2",
"Statsmodels",
"SimpleComparisonBase",
"WilcoxonTest",
"TTest",
"ContrastType",
]

AVAILABLE_METHODS = [Statsmodels, EdgeR, PyDESeq2, WilcoxonTest, TTest]
84 changes: 54 additions & 30 deletions src/multi_condition_comparisions/methods/_base.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
import re
from abc import ABC, abstractmethod
from collections.abc import Sequence
from collections.abc import Mapping, Sequence
from dataclasses import dataclass
from types import MappingProxyType

import numpy as np
import pandas as pd
from anndata import AnnData
from formulaic import model_matrix
from formulaic.model_matrix import ModelMatrix

from multi_condition_comparisions._util import check_is_numeric_matrix


@dataclass
class Contrast:
Expand All @@ -32,7 +35,7 @@ def __init__(
**kwargs,
):
"""
Initialize the method
Initialize the method.
Parameters
----------
Expand All @@ -53,17 +56,12 @@ def __init__(

self.layer = layer

# Do some sanity checks on the input. Do them after the mask is applied.
# Check that counts have no NaN or Inf values.
if np.any(~np.isfinite(self.data)):
raise ValueError("Counts cannot contain negative, NaN or Inf values.")
# Check that counts have numeric values.
if not np.issubdtype(self.adata.X.dtype, np.number):
raise ValueError("Counts must be numeric.")
# Check after mask has been applied.
check_is_numeric_matrix(self.data)

@property
def data(self):
"""Get the data matrix from anndata this object was initalized with (X or layer)"""
"""Get the data matrix from anndata this object was initalized with (X or layer)."""
if self.layer is None:
return self.adata.X
else:
Expand All @@ -76,11 +74,13 @@ def compare_groups(
adata: AnnData,
column: str,
baseline: str,
groups_to_compare: str | Sequence[str],
groups_to_compare: str | Sequence[str] | None,
*,
paired_by: str = None,
paired_by: str | None = None,
mask: str | None = None,
layer: str | None = None,
fit_kwargs: Mapping = MappingProxyType({}),
test_kwargs: Mapping = MappingProxyType({}),
) -> pd.DataFrame:
"""
Compare between groups in a specified column.
Expand All @@ -97,10 +97,10 @@ def compare_groups(
column
column in obs that contains the grouping information
baseline
baseline value (one category from variable). If set to "None" this refers to "all other categories".
baseline value (one category from variable).
groups_to_compare
One or multiple categories from variable to compare against baseline. Setting this to None refers to
"all categories"
"all other categories"
paired_by
Column from `obs` that contains information about paired sample (e.g. subject_id)
mask
Expand Down Expand Up @@ -163,24 +163,48 @@ def compare_groups(
paired_by: str | None = None,
mask: str | None = None,
layer: str | None = None,
fit_kwargs: dict = None,
test_kwargs: dict = None,
fit_kwargs: Mapping = MappingProxyType({}),
test_kwargs: Mapping = MappingProxyType({}),
) -> pd.DataFrame:
if test_kwargs is None:
test_kwargs = {}
if fit_kwargs is None:
fit_kwargs = {}
"""
Compare between groups in a specified column.
This is a high-level interface that is kept simple on purpose and
only supports comparisons between groups on a single column at a time.
For more complex designs, please use the LinearModel method classes directly.
Parameters
----------
adata
AnnData object
column
column in obs that contains the grouping information
baseline
baseline value (one category from variable).
groups_to_compare
One or multiple categories from variable to compare against baseline. Setting this to None refers to
"all other categories"
paired_by
Column from `obs` that contains information about paired sample (e.g. subject_id)
mask
Subset anndata by a boolean mask stored in this column in `.obs` before making any tests
layer
Use this layer instead of `.X`.
Returns
-------
Pandas dataframe with results ordered by significance. If multiple comparisons were performed this
is indicated in an additional column.
"""
design = f"~{column}"
if paired_by is not None:
design += f"+{paired_by}"
if isinstance(groups_to_compare, str):
groups_to_compare = [groups_to_compare]
model = cls(adata, design=design, mask=mask, layer=layer)

## Fit model
model.fit(**fit_kwargs)

## Test contrasts
de_res = model.test_contrasts(
{
group_to_compare: model.contrast(column=column, baseline=baseline, group_to_compare=group_to_compare)
Expand Down Expand Up @@ -234,15 +258,15 @@ def test_contrasts(
self, contrasts: list[str] | dict[str, np.ndarray] | dict[str, list] | np.ndarray, **kwargs
) -> pd.DataFrame:
"""
Conduct a specific test. Please use :method:`contrast` to build the contrasts instead of building it on your own.
Conduct a specific test.
Please use :method:`contrast` to build the contrasts instead of building it on your own.
Parameters
----------
contrasts:
either a single contrast, or a dictionary of contrasts where the key is the name for that particular contrast.
Each contrast can be either a vector of coefficients (the most general case), a string, or a some fancy DSL
(details still need to be figured out).
Each contrast can be either a vector of coefficients (the most general case), a string, a DSL (work in progress)
or a tuple with three elements contrasts = ("condition", "control", "treatment")
"""
if not isinstance(contrasts, dict):
Expand All @@ -252,9 +276,6 @@ def test_contrasts(
results.append(self._test_single_contrast(contrast, **kwargs).assign(contrast=name))

results_df = pd.concat(results)
results_df.rename(
columns={"pvalue": "pvals", "padj": "pvals_adj", "log2FoldChange": "logfoldchanges"}, inplace=True
)

return results_df

Expand Down Expand Up @@ -328,5 +349,8 @@ def _get_var_from_colname(colname):
return self.design.model_spec.get_model_matrix(df)

def contrast(self, column: str, baseline: str, group_to_compare: str) -> list:
"""Build a simple contrast for pairwise comparisons. In the future all methods should be able to accept the output of :method:`StatsmodelsDE.contrast` but alas a big TODO."""
"""Build a simple contrast for pairwise comparisons.
In the future all methods should be able to accept the output of :method:`StatsmodelsDE.contrast` but alas a big TODO.
"""
return [column, baseline, group_to_compare]
9 changes: 6 additions & 3 deletions src/multi_condition_comparisions/methods/_edger.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,13 +44,14 @@ def fit(self, **kwargs): # adata, design, mask, layer
rpy2.robjects.numpy2ri.activate()

except ImportError:
raise ImportError("edger requires rpy2 to be installed. ") from None
raise ImportError("edger requires rpy2 to be installed.") from None

try:
edger = importr("edgeR")
except ImportError:
raise ImportError(
"edgeR requires a valid R installation with the following packages: " "edgeR, BiocParallel, RhpcBLASctl"
"edgeR requires a valid R installation with the following packages:\n"
"edgeR, BiocParallel, RhpcBLASctl"
) from None

# Convert dataframe
Expand Down Expand Up @@ -136,5 +137,7 @@ def _test_single_contrast(self, contrast: list[str], **kwargs) -> pd.DataFrame:

# Convert results to pandas
de_res = ro.conversion.rpy2py(ro.globalenv["de_res"])
de_res.index.name = "variable"
de_res = de_res.reset_index()

return de_res.rename(columns={"PValue": "pvals", "logFC": "logfoldchanges", "FDR": "pvals_adj"})
return de_res.rename(columns={"PValue": "p_value", "logFC": "log_fc", "FDR": "adj_p_value"})
28 changes: 25 additions & 3 deletions src/multi_condition_comparisions/methods/_pydeseq2.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,12 @@
import warnings

import pandas as pd
from anndata import AnnData
from numpy import ndarray
from pydeseq2.dds import DeseqDataSet
from pydeseq2.default_inference import DefaultInference
from pydeseq2.ds import DeseqStats
from scipy.sparse import issparse

from multi_condition_comparisions._util import check_is_integer_matrix

Expand All @@ -15,6 +18,18 @@
class PyDESeq2(LinearModelBase):
"""Differential expression test using a PyDESeq2"""

def __init__(
self, adata: AnnData, design: str | ndarray, *, mask: str | None = None, layer: str | None = None, **kwargs
):
super().__init__(adata, design, mask=mask, layer=layer, **kwargs)
# work around pydeseq2 issue with sparse matrices
# see also https://github.com/owkin/PyDESeq2/issues/25
if issparse(self.data):
if self.layer is None:
self.adata.X = self.adata.X.toarray()
else:
self.adata.layers[self.layer] = self.adata.layers[self.layer].toarray()

def _check_counts(self):
check_is_integer_matrix(self.data)

Expand Down Expand Up @@ -50,7 +65,7 @@ def fit(self, **kwargs) -> pd.DataFrame:

def _test_single_contrast(self, contrast: list[str], alpha=0.05, **kwargs) -> pd.DataFrame:
"""
Conduct a specific test and returns a data frame
Conduct a specific test and returns a Pandas DataFrame.
Parameters
----------
Expand All @@ -62,6 +77,13 @@ def _test_single_contrast(self, contrast: list[str], alpha=0.05, **kwargs) -> pd
kwargs: extra arguments to pass to DeseqStats()
"""
stat_res = DeseqStats(self.dds, contrast=contrast, alpha=alpha, **kwargs)
# Calling `.summary()` is required to fill the `results_df` data frame
stat_res.summary()
stat_res.p_values
return pd.DataFrame(stat_res.results_df).sort_values("padj")
res_df = (
pd.DataFrame(stat_res.results_df)
.rename(columns={"pvalue": "p_value", "padj": "adj_p_value", "log2FoldChange": "log_fc"})
.sort_values("p_value")
)
res_df.index.name = "variable"
res_df = res_df.reset_index()
return res_df
Loading

0 comments on commit 7610d50

Please sign in to comment.