Skip to content
This repository has been archived by the owner on May 27, 2024. It is now read-only.

Commit

Permalink
Fix tests
Browse files Browse the repository at this point in the history
Signed-off-by: zethson <[email protected]>
  • Loading branch information
Zethson committed Dec 28, 2023
1 parent 3929f50 commit 8cc646d
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 18 deletions.
8 changes: 6 additions & 2 deletions src/multi_condition_comparisions/tl/de.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def __init__(
# Do some sanity checks on the input. Do them after the mask is applied.

# Check that counts have no NaN or Inf values.
if np.any(adata.X < 0 or np.isnan(self.adata.X)) or np.any(np.isinf(self.adata.X)):
if np.any(np.logical_or(adata.X < 0, np.isnan(self.adata.X))) or np.any(np.isinf(self.adata.X)):
raise ValueError("Counts cannot contain negative, NaN or Inf values.")
# Check that counts have numeric values.
if not np.issubdtype(self.adata.X.dtype, np.number):
Expand Down Expand Up @@ -132,7 +132,11 @@ def test_contrasts(self, contrasts: list[str] | dict[str, np.ndarray] | np.ndarr
for name, contrast in contrasts.items():
results.append(self._test_single_contrast(contrast, **kwargs).assign(contrast=name))

return pd.concat(results)
results = pd.concat(results)
results.rename(columns={"pvalue": "pvals", "padj": "pvals_adj", "log2FoldChange": "logfoldchanges"},
inplace=True)

return results

def test_reduced(self, modelB: "BaseMethod") -> pd.DataFrame:
"""
Expand Down
32 changes: 16 additions & 16 deletions tests/test_de.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,44 +23,44 @@ def test_package_has_version():
),
],
)
def test_de(test_adata, statsmodels_stub, kwargs):
def test_de(test_adata, method_class, kwargs):
"""Check that the method can be initialized and fitted, and perform basic checks on
the result of test_contrasts."""
method = statsmodels_stub(adata=test_adata, design="~condition") # type: ignore
method = method_class(adata=test_adata, design="~condition") # type: ignore
method.fit(**kwargs)
res_df = method.test_contrasts(np.array([0, 1]))
# Check that the result has the correct number of rows
assert len(res_df) == test_adata.n_vars


def test_pydeseq2de(test_adata):
"""Check that the pyDESeq2 method can be initialized and fitted and that the test_contrast
method returns a dataframe with the correct number of rows.
Now this is a separate
def test_pydeseq2_simple(test_adata):
"""Check that the pyDESeq2 method can be
1. Initialized
2. Fitted
3. and that test_contrast returns a DataFrame with the correct number of rows.
"""
method = PyDESeq2DE(adata=test_adata, design="~condition")
method.fit()
res_df = method.test_contrasts(["condition", "A", "B"])

assert len(res_df) == test_adata.n_vars


def test_pydeseq2de2(test_adata):
def test_pydeseq2_complex(test_adata):
"""Check that the pyDESeq2 method can be initialized with a different covariate name and fitted and that the test_contrast
method returns a dataframe with the correct number of rows.
Now this is a separate
"""
test_adata.obs["condition1"] = test_adata.obs["condition"].copy()
method = PyDESeq2DE(adata=test_adata, design="~condition1+group")
method.fit()
res_df = method.test_contrasts(["condition1", "A", "B"])

assert len(res_df) == test_adata.n_vars
# Check that the index of the result matches the var_names of the adata
# Check that the index of the result matches the var_names of the AnnData object
tm.assert_index_equal(test_adata.var_names, res_df.index, check_order=False, check_names=False)
# Check that there is a p-value column
assert "pvalue" in res_df.columns
# Check that p-values are between 0 and 1
assert np.all((0 <= res_df["pvalue"]) & (res_df["pvalue"] <= 1))

expected_columns = {"pvals", "pvals_adj", "logfoldchanges"}
assert expected_columns.issubset(set(res_df.columns))
assert np.all((0 <= res_df["pvals"]) & (res_df["pvals"] <= 1))

0 comments on commit 8cc646d

Please sign in to comment.