Skip to content
This repository has been archived by the owner on May 27, 2024. It is now read-only.

T-test / Improve simple tests #38

Merged
merged 22 commits into from
Apr 2, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions src/multi_condition_comparisions/methods/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,16 @@
from ._base import ContrastType, LinearModelBase, MethodBase
from ._edger import EdgeR
from ._pydeseq2 import PyDESeq2
from ._simple_tests import WilcoxonTest
from ._simple_tests import SimpleComparisonBase, TTest, WilcoxonTest
from ._statsmodels import Statsmodels

__all__ = ["MethodBase", "LinearModelBase", "EdgeR", "PyDESeq2", "Statsmodels", "WilcoxonTest", "ContrastType"]
# Public API of the `methods` subpackage.
# Every entry needs its own trailing comma: adjacent string literals are
# concatenated implicitly, which would silently export a non-existent name
# (e.g. "TTestContrastType") and drop the two real ones.
__all__ = [
    "MethodBase",
    "LinearModelBase",
    "EdgeR",
    "PyDESeq2",
    "Statsmodels",
    "SimpleComparisonBase",
    "WilcoxonTest",
    "TTest",
    "ContrastType",
]
21 changes: 10 additions & 11 deletions src/multi_condition_comparisions/methods/_base.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import re
from abc import ABC, abstractmethod
from collections.abc import Sequence
from collections.abc import Mapping, Sequence
from dataclasses import dataclass
from types import MappingProxyType

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -76,11 +77,13 @@ def compare_groups(
adata: AnnData,
column: str,
baseline: str,
groups_to_compare: str | Sequence[str],
groups_to_compare: str | Sequence[str] | None,
*,
paired_by: str = None,
paired_by: str | None = None,
mask: str | None = None,
layer: str | None = None,
fit_kwargs: Mapping = MappingProxyType({}),
test_kwargs: Mapping = MappingProxyType({}),
) -> pd.DataFrame:
"""
Compare between groups in a specified column.
Expand All @@ -97,10 +100,10 @@ def compare_groups(
column
column in obs that contains the grouping information
baseline
baseline value (one category from variable). If set to "None" this refers to "all other categories".
baseline value (one category from variable).
groups_to_compare
One or multiple categories from variable to compare against baseline. Setting this to None refers to
"all categories"
"all other categories"
paired_by
Column from `obs` that contains information about paired sample (e.g. subject_id)
mask
Expand Down Expand Up @@ -163,13 +166,9 @@ def compare_groups(
paired_by: str | None = None,
mask: str | None = None,
layer: str | None = None,
fit_kwargs: dict = None,
test_kwargs: dict = None,
fit_kwargs: Mapping = MappingProxyType({}),
test_kwargs: Mapping = MappingProxyType({}),
) -> pd.DataFrame:
if test_kwargs is None:
test_kwargs = {}
if fit_kwargs is None:
fit_kwargs = {}
design = f"~{column}"
if paired_by is not None:
design += f"+{paired_by}"
Expand Down
136 changes: 95 additions & 41 deletions src/multi_condition_comparisions/methods/_simple_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,24 +2,50 @@

import warnings
from abc import abstractmethod
from collections.abc import Sequence
from collections.abc import Mapping, Sequence
from types import MappingProxyType

import numpy as np
import pandas as pd
import scipy.stats
from anndata import AnnData
from pandas.core.api import DataFrame as DataFrame
from scipy.sparse import issparse
from scipy.sparse import diags, issparse
from tqdm.auto import tqdm

from ._base import MethodBase


class SimpleComparisonBase(MethodBase):
# Abstract hook implemented by every concrete simple test. `_compare_single_group`
# calls it once per variable with the flattened values of the baseline and the
# comparison observations and stores the returned p-value.
@staticmethod
@abstractmethod
def _test(x0: np.ndarray, x1: np.ndarray, paired: bool, **kwargs) -> float:
"""
Perform a statistical test between values in x0 and x1.

If `paired` is True, x0 and x1 must be of the same length and ordered such that
paired elements have the same position.

Parameters
----------
x0
array with baseline values
x1
array with values to compare
paired
indicates whether to perform a paired test
**kwargs
kwargs passed to the test function

Returns
-------
p-value
"""
# No default implementation -- subclasses must override this method.
...

def _compare_single_group(
    self, baseline_idx: np.ndarray, comparison_idx: np.ndarray, *, paired: bool, **kwargs
) -> DataFrame:
    """
    Perform a single comparison between two groups.

    Parameters
    ----------
    baseline_idx
        numeric row indices into `self.data` of the observations that form the baseline group
    comparison_idx
        numeric row indices into `self.data` of the observations that are compared against the baseline
    paired
        whether or not to perform a paired test. Note that in the case of a paired test,
        the indices must be ordered such that paired observations appear at the same position.
    **kwargs
        kwargs passed to the test function

    Returns
    -------
    DataFrame indexed by variable and sorted by p-value, with the columns `pvals` and
    `fold_change` (difference of the natural logs of the per-variable group means,
    comparison minus baseline).
    """
    if paired:
        assert len(baseline_idx) == len(comparison_idx), "For a paired test, indices must be of the same length"

    x0 = self.data[baseline_idx, :]
    x1 = self.data[comparison_idx, :]

    # In the following loop, we are doing a lot of column slicing -- which is significantly
    # more efficient in csc format.
    if issparse(self.data):
        x0 = x0.tocsc()
        x1 = x1.tocsc()

    res = []
    for var in tqdm(self.adata.var_names):
        # Compute the column selector once and reuse it for both groups.
        var_mask = self.adata.var_names == var
        tmp_x0 = x0[:, var_mask]
        tmp_x0 = np.asarray(tmp_x0.todense()).flatten() if issparse(tmp_x0) else tmp_x0.flatten()
        tmp_x1 = x1[:, var_mask]
        tmp_x1 = np.asarray(tmp_x1.todense()).flatten() if issparse(tmp_x1) else tmp_x1.flatten()
        pval = self._test(tmp_x0, tmp_x1, paired, **kwargs)
        # The means must be taken over the current variable only. Averaging the full
        # x0/x1 matrices would assign the same fold change to every variable.
        mean_x0 = np.mean(tmp_x0)
        mean_x1 = np.mean(tmp_x1)
        res.append({"variable": var, "pvals": pval, "fold_change": np.log(mean_x1) - np.log(mean_x0)})
    return pd.DataFrame(res).sort_values("pvals").set_index("variable")

@classmethod
def compare_groups(
Expand All @@ -46,30 +96,40 @@ def compare_groups(
paired_by: str | None = None,
mask: str | None = None,
layer: str | None = None,
fit_kwargs: dict = None,
test_kwargs: dict = None,
fit_kwargs: Mapping = MappingProxyType({}),
test_kwargs: Mapping = MappingProxyType({}),
) -> DataFrame:
if test_kwargs is None:
test_kwargs = {}
if fit_kwargs is None:
fit_kwargs = {}
if len(fit_kwargs) or len(test_kwargs):
warnings.warn("Simple tests do not use fit or test kwargs", UserWarning, stacklevel=2)
if paired_by is not None:
adata = adata.copy()[adata.obs.sort_values(paired_by).index, :]
if len(fit_kwargs):
warnings.warn("fit_kwargs not used for simple tests.", UserWarning, stacklevel=2)
paired = paired_by is not None
model = cls(adata, mask=mask, layer=layer)
if groups_to_compare is None:
# compare against all other
groups_to_compare = sorted(set(model.adata.obs[column]) - {baseline})
if isinstance(groups_to_compare, str):
groups_to_compare = [groups_to_compare]

def _get_idx(column, value):
mask = model.adata.obs[column] == value
if paired:
dummies = pd.get_dummies(model.adata.obs[paired_by], sparse=True).sparse.to_coo().tocsr()
if not np.all(np.sum(dummies, axis=0) == 2):
raise ValueError("Pairing is only possible with exactly two values per group")
# Use matrix multiplication to only retrieve those dummy entries that are associated with the current `value`.
# Convert to COO matrix to get rows/cols
# row indices refers to the indices of rows that have `column == value` (equivalent to np.where(mask)[0])
# col indices refers to the numeric index of each "pair" in obs_names
ind_mat = (diags(mask.values, dtype=bool) @ dummies).tocoo()
return ind_mat.row[np.argsort(ind_mat.col)]
else:
return np.where(mask)[0]

res_dfs = []
baseline_idx = _get_idx(column, baseline)
for group_to_compare in groups_to_compare:
comparison_idx = np.where(adata.obs[column] == group_to_compare)[0]
if baseline is None:
baseline_idx = np.where(adata.obs[column] != group_to_compare)[0]
else:
baseline_idx = np.where(adata.obs[column] == baseline)[0]
comparison_idx = _get_idx(column, group_to_compare)
res_dfs.append(
model._compare_single_group(baseline_idx, comparison_idx).assign(
model._compare_single_group(baseline_idx, comparison_idx, paired=paired, **test_kwargs).assign(
comparison=f"{group_to_compare}_vs_{baseline if baseline is not None else 'rest'}"
)
)
Expand All @@ -82,26 +142,20 @@ class WilcoxonTest(SimpleComparisonBase):
(the former is also known as "Mann-Whitney U test", the latter as "wilcoxon signed rank test")
"""

@staticmethod
def _test(x0: np.ndarray, x1: np.ndarray, paired: bool, **kwargs) -> float:
    """
    Return the p-value of a Wilcoxon test between x0 and x1.

    Uses `scipy.stats.wilcoxon` when `paired` is true and `scipy.stats.mannwhitneyu`
    otherwise; `kwargs` are forwarded to the selected scipy function.
    """
    test_fun = scipy.stats.wilcoxon if paired else scipy.stats.mannwhitneyu
    return test_fun(x0, x1, **kwargs).pvalue
grst marked this conversation as resolved.
Show resolved Hide resolved

# TODO can be made more efficient by converting CSR/CSC matrices accordingly
x0 = self.data[baseline_idx, :]
x1 = self.data[comparison_idx, :]
res = []
for var in tqdm(self.adata.var_names):
tmp_x0 = x0[:, self.adata.var_names == var]
tmp_x0 = np.asarray(tmp_x0.todense()).flatten() if issparse(tmp_x0) else tmp_x0.flatten()
tmp_x1 = x1[:, self.adata.var_names == var]
tmp_x1 = np.asarray(tmp_x1.todense()).flatten() if issparse(tmp_x1) else tmp_x1.flatten()
pval = test_fun(x=tmp_x0, y=tmp_x1).pvalue
mean_x0 = np.asarray(np.mean(x0, axis=0)).flatten().astype(dtype=float)
mean_x1 = np.asarray(np.mean(x1, axis=0)).flatten().astype(dtype=float)
res.append({"variable": var, "pvals": pval, "fold_change": np.log(mean_x1) - np.log(mean_x0)})
return pd.DataFrame(res).sort_values("pvals").set_index("variable")

class TTest(SimpleComparisonBase):
    """Perform an unpaired or paired T-test"""

    @staticmethod
    def _test(x0: np.ndarray, x1: np.ndarray, paired: bool, **kwargs) -> float:
        """
        Return the p-value of a T-test between x0 and x1.

        Uses `scipy.stats.ttest_rel` when `paired` is true and `scipy.stats.ttest_ind`
        otherwise; `kwargs` are forwarded to the selected scipy function.
        """
        test_fun = scipy.stats.ttest_rel if paired else scipy.stats.ttest_ind
        return test_fun(x0, x1, **kwargs).pvalue
grst marked this conversation as resolved.
Show resolved Hide resolved
11 changes: 6 additions & 5 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import numpy as np
import pandas as pd
import pytest
import scipy.sparse as sp
from pydeseq2.utils import load_example_data


Expand All @@ -28,21 +29,21 @@ def test_adata(test_counts, test_metadata):
return ad.AnnData(X=test_counts, obs=test_metadata)


@pytest.fixture
def test_adata_minimal():
@pytest.fixture(params=[np.array, sp.csr_matrix, sp.csc_matrix])
def test_adata_minimal(request):
grst marked this conversation as resolved.
Show resolved Hide resolved
n_obs = 80
n_donors = n_obs // 4
rng = np.random.default_rng(9) # make tests deterministic
obs = pd.DataFrame(
{
"condition": ["A", "B"] * (n_obs // 2),
"donor": sum(([f"D{i}"] * n_donors for i in range(n_obs // n_donors)), []),
"other": (["X"] * (n_obs // 4)) + (["Y"] * ((3 * n_obs) // 4)),
"pairing": sum(([str(i), str(i)] for i in range(n_obs // 2)), []),
"continuous": np.random.uniform(0, 1) * 4000,
"continuous": [rng.uniform(0, 1) * 4000 for _ in range(n_obs)],
},
)
var = pd.DataFrame(index=["gene1", "gene2"])
rng = np.random.default_rng(9) # make tests deterministic
group1 = rng.negative_binomial(20, 0.1, n_obs // 2) # large mean
group2 = rng.negative_binomial(5, 0.5, n_obs // 2) # small mean

Expand All @@ -57,5 +58,5 @@ def test_adata_minimal():
donor_data[(2 * n_donors) : (3 * n_donors)] = group2[:n_donors]
donor_data[(3 * n_donors) :] = group1[n_donors:]

X = np.vstack([condition_data, donor_data]).T
X = request.param(np.vstack([condition_data, donor_data]).T)
grst marked this conversation as resolved.
Show resolved Hide resolved
return ad.AnnData(X=X, obs=obs, var=var)
47 changes: 44 additions & 3 deletions tests/test_de.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import numpy as np
import pandas as pd
import pytest
import statsmodels.api as sm
from pandas import testing as tm
from pandas.core.api import DataFrame as DataFrame

import multi_condition_comparisions
from multi_condition_comparisions.methods import EdgeR, PyDESeq2, Statsmodels, WilcoxonTest
from multi_condition_comparisions.methods import EdgeR, PyDESeq2, SimpleComparisonBase, Statsmodels, TTest, WilcoxonTest


def test_package_has_version():
Expand All @@ -23,7 +25,7 @@ def test_package_has_version():
),
],
)
def test_de(test_adata, method_class, kwargs):
def test_statsmodels(test_adata, method_class, kwargs):
"""Check that the method can be initialized and fitted, and perform basic checks on
the result of test_contrasts."""
method = method_class(adata=test_adata, design="~condition") # type: ignore
Expand Down Expand Up @@ -98,10 +100,49 @@ def test_edger_complex(test_adata):


@pytest.mark.parametrize("paired_by", [None, "pairings"])
def test_non_parametric(test_adata, paired_by):
def test_wilcoxon(test_adata, paired_by):
# Smoke test for WilcoxonTest.compare_groups, run once unpaired and once paired.
if paired_by is not None:
# Build the pairing column on the fly: the ids 0..n_A-1 repeated twice, where n_A is
# the number of observations with condition "A".
# NOTE(review): presumably relies on the fixture holding equally many "A" and "B"
# observations in a matching order -- confirm against the `test_adata` fixture.
test_adata.obs[paired_by] = list(range(sum(test_adata.obs["condition"] == "A"))) * 2
res_df = WilcoxonTest.compare_groups(
adata=test_adata, column="condition", baseline="A", groups_to_compare="B", paired_by=paired_by
)
# Only checks that the p-values are valid probabilities, not that they are significant.
assert np.all((0 <= res_df["pvals"]) & (res_df["pvals"] <= 1)) # TODO: which of these should actually be <.05?


# Smoke test for TTest.compare_groups, run once unpaired and once paired.
@pytest.mark.parametrize("paired_by", [None, "pairings"])
def test_t(test_adata, paired_by):
if paired_by is not None:
# Build the pairing column on the fly: the ids 0..n_A-1 repeated twice, where n_A is
# the number of observations with condition "A".
# NOTE(review): presumably relies on the fixture holding equally many "A" and "B"
# observations in a matching order -- confirm against the `test_adata` fixture.
test_adata.obs[paired_by] = list(range(sum(test_adata.obs["condition"] == "A"))) * 2
res_df = TTest.compare_groups(
adata=test_adata, column="condition", baseline="A", groups_to_compare="B", paired_by=paired_by
)
# Only checks that the p-values are valid probabilities, not that they are significant.
assert np.all((0 <= res_df["pvals"]) & (res_df["pvals"] <= 1)) # TODO: which of these should actually be <.05?


# Runs with ten different seeds so that ten different observation orders are checked.
@pytest.mark.parametrize("seed", range(10))
def test_simple_comparison_pairing(test_adata_minimal, seed):
"""Test that paired samples are properly matched in a paired test"""

# Mock that replaces the actual statistics with assertions on the indices that
# `compare_groups` hands to `_compare_single_group`.
class MockSimpleComparison(SimpleComparisonBase):
# Stub that only satisfies the abstract interface; the overridden
# `_compare_single_group` below never calls it.
@staticmethod
def _test():
return None

def _compare_single_group(
self, baseline_idx: np.ndarray, comparison_idx: np.ndarray, *, paired: bool = False, **kwargs
) -> DataFrame:
assert paired
x0 = self.adata[baseline_idx, :]
x1 = self.adata[comparison_idx, :]
# Each index array must select observations of exactly one condition ...
assert np.all(x0.obs["condition"] == "A")
assert np.all(x1.obs["condition"] == "B")
# ... and position i in both arrays must belong to the same pair.
assert np.all(x0.obs["pairing"].values == x1.obs["pairing"].values)
return pd.DataFrame()

# Shuffle the observations so that the pairing has to be recovered from the
# "pairing" column rather than from the original row order.
rng = np.random.default_rng(seed)
shuffle_adata_idx = rng.permutation(test_adata_minimal.obs_names)
tmp_adata = test_adata_minimal[shuffle_adata_idx, :].copy()

MockSimpleComparison.compare_groups(
tmp_adata, column="condition", baseline="A", groups_to_compare=["B"], paired_by="pairing"
)
4 changes: 2 additions & 2 deletions tests/test_unified_api.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import pytest

from multi_condition_comparisions.methods import EdgeR, PyDESeq2, Statsmodels, WilcoxonTest
from multi_condition_comparisions.methods import EdgeR, PyDESeq2, Statsmodels, TTest, WilcoxonTest


@pytest.mark.parametrize("method", [WilcoxonTest, Statsmodels, PyDESeq2, EdgeR])
@pytest.mark.parametrize("method", [WilcoxonTest, TTest, Statsmodels, PyDESeq2, EdgeR])
@pytest.mark.parametrize("paired_by", ["pairing", None])
def test_unified(test_adata_minimal, method, paired_by):
res_df = method.compare_groups(
Expand Down
Loading