Skip to content
This repository has been archived by the owner on May 27, 2024. It is now read-only.

Commit

Permalink
ENH Improve StatsmodelsDE with GLMs
Browse files Browse the repository at this point in the history
  • Loading branch information
BorisMuzellec authored Nov 28, 2023
2 parents 472bbef + 361cc08 commit 75b6de4
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 10 deletions.
45 changes: 38 additions & 7 deletions src/multi_condition_comparisions/tl/de.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import numpy as np
import pandas as pd
import scanpy as sc
import statsmodels.regression.linear_model
import statsmodels.api as sm
from anndata import AnnData
from formulaic import model_matrix
from formulaic.model_matrix import ModelMatrix
Expand All @@ -12,7 +12,12 @@

class BaseMethod(ABC):
def __init__(
self, adata: AnnData, design: str | np.ndarray, mask: str | None = None, layer: str | None = None, **kwargs
self,
adata: AnnData,
design: str | np.ndarray,
mask: str | None = None,
layer: str | None = None,
**kwargs,
):
"""
Initialize the method
Expand All @@ -24,7 +29,9 @@ def __init__(
design
Model design. Can be either a design matrix, a formulaic formula.
mask
a column in adata.var that contains a boolean mask with selected features.
A column in adata.var that contains a boolean mask with selected features.
layer
Layer to use in fit(). If None, use the X matrix.
**kwargs
Keyword arguments specific to the method implementation
"""
Expand Down Expand Up @@ -137,12 +144,36 @@ def contrast(self, column: str, baseline: str, group_to_compare: str) -> np.ndar
class StatsmodelsDE(BaseMethod):
"""Differential expression test using a statsmodels linear regression"""

def fit(self):
"""Fit the OLS model"""
def fit(
self,
regression_model: sm.OLS | sm.GLM = sm.OLS,
**kwargs,
) -> None:
"""
Fit the specified regression model.
Parameters
----------
regression_model
A statsmodels regression model class, either OLS or GLM. Defaults to OLS.
**kwargs
Additional arguments for fitting the specific method. In particular, this
is where you can specify the family for GLM.
Example
-------
>>> import statsmodels.api as sm
>>> model = StatsmodelsDE(adata, design="~condition")
>>> model.fit(sm.GLM, family=sm.families.NegativeBinomial(link=sm.families.links.Log()))
>>> results = model.test_contrasts(np.array([0, 1]))
"""
self.models = []
for var in tqdm(self.adata.var_names):
mod = statsmodels.regression.linear_model.OLS(
sc.get.obs_df(self.adata, keys=[var], layer=self.layer)[var], self.design
mod = regression_model(
sc.get.obs_df(self.adata, keys=[var], layer=self.layer)[var],
self.design,
**kwargs,
)
mod = mod.fit()
self.models.append(mod)
Expand Down
18 changes: 15 additions & 3 deletions tests/test_de.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import anndata as ad
import numpy as np
import pytest
import statsmodels.api as sm
from pydeseq2.utils import load_example_data

import multi_condition_comparisions
Expand Down Expand Up @@ -28,11 +29,22 @@ def test_adata():
return ad.AnnData(X=counts, obs=metadata)


@pytest.mark.parametrize("method_class", [StatsmodelsDE])
def test_de(test_adata, method_class: BaseMethod):
@pytest.mark.parametrize(
"method_class,kwargs",
[
# OLS
(StatsmodelsDE, {}),
# Negative Binomial
(
StatsmodelsDE,
{"regression_model": sm.GLM, "family": sm.families.NegativeBinomial()},
),
],
)
def test_de(test_adata, method_class: BaseMethod, kwargs):
"""Check that the method can be initialized and fitted and that the test_contrast
method returns a dataframe with the correct number of rows"""
method = method_class(adata=test_adata, design="~condition")
method.fit()
method.fit(**kwargs)
res_df = method.test_contrasts(np.array([0, 1]))
assert len(res_df) == test_adata.n_vars

0 comments on commit 75b6de4

Please sign in to comment.