From 361cc08ee15091d6cd7f8b40cd3f267c3d44786e Mon Sep 17 00:00:00 2001 From: Boris MUZELLEC Date: Tue, 28 Nov 2023 15:21:07 +0100 Subject: [PATCH] test: add test for StatsmodelsDE with NegativeBinomial model --- tests/test_de.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/tests/test_de.py b/tests/test_de.py index 8bb7589..26249af 100644 --- a/tests/test_de.py +++ b/tests/test_de.py @@ -1,6 +1,7 @@ import anndata as ad import numpy as np import pytest +import statsmodels.api as sm from pydeseq2.utils import load_example_data import multi_condition_comparisions @@ -28,11 +29,22 @@ def test_adata(): return ad.AnnData(X=counts, obs=metadata) -@pytest.mark.parametrize("method_class", [StatsmodelsDE]) -def test_de(test_adata, method_class: BaseMethod): +@pytest.mark.parametrize( + "method_class,kwargs", + [ + # OLS + (StatsmodelsDE, {}), + # Negative Binomial + ( + StatsmodelsDE, + {"regression_model": sm.GLM, "family": sm.families.NegativeBinomial()}, + ), + ], +) +def test_de(test_adata, method_class: BaseMethod, kwargs): """Check that the method can be initialized and fitted and that the test_contrast method returns a dataframe with the correct number of rows""" method = method_class(adata=test_adata, design="~condition") - method.fit() + method.fit(**kwargs) res_df = method.test_contrasts(np.array([0, 1])) assert len(res_df) == test_adata.n_vars