From 647f2578e9e74ff0b6d952c8c0db25ef9fb5bd9a Mon Sep 17 00:00:00 2001 From: HUANGERIC970823 Date: Sun, 18 Feb 2024 12:06:21 +0800 Subject: [PATCH] dunnett alter arg and function test --- scikit_posthocs/_posthocs.py | 10 ++++++++-- tests/test_posthocs.py | 15 +++++++++++++++ 2 files changed, 23 insertions(+), 2 deletions(-) diff --git a/scikit_posthocs/_posthocs.py b/scikit_posthocs/_posthocs.py index 137c01d..23f24b3 100644 --- a/scikit_posthocs/_posthocs.py +++ b/scikit_posthocs/_posthocs.py @@ -1,5 +1,5 @@ import itertools as it -from typing import Tuple, Union +from typing import Tuple, Union, Literal import numpy as np import scipy.stats as ss from statsmodels.sandbox.stats.multicomp import multipletests @@ -1565,6 +1565,7 @@ def posthoc_dunnett(a: Union[list, np.ndarray, DataFrame], val_col: str = None, group_col: str = None, control: str = None, + alternative: Literal['two-sided', 'less', 'greater'] = 'two-sided', sort: bool = False, to_matrix: bool = True) -> Series | DataFrame: """ @@ -1591,6 +1592,11 @@ def posthoc_dunnett(a: Union[list, np.ndarray, DataFrame], Name of the control group within the `group_col` column. Values should have a nominal scale (categorical). Must be specified if `a` is a pandas + alternative : ['two-sided', 'less', or 'greater'], optional + Whether to get the p-value for the one-sided hypothesis + ('less' or 'greater') or for the two-sided hypothesis ('two-sided'). + Defaults to 'two-sided'. + sort : bool, optional Specifies whether to sort DataFrame by group_col or not. Recommended unless you sort your data manually. @@ -1618,7 +1624,7 @@ def posthoc_dunnett(a: Union[list, np.ndarray, DataFrame], control_data = x_embedded.loc[control] treatment_data = x_embedded.drop(control) - pvals = ss.dunnett(*treatment_data, control=control_data).pvalue + pvals = ss.dunnett(*treatment_data, control=control_data, alternative=alternative).pvalue multi_index = MultiIndex.from_product([[control], treatment_data.index.tolist()]) dunnett_sr = Series(pvals, index=multi_index) diff --git a/tests/test_posthocs.py b/tests/test_posthocs.py index 2cb19d4..d601caf 100644 --- a/tests/test_posthocs.py +++ b/tests/test_posthocs.py @@ -597,6 +597,21 @@ def test_posthoc_tukey(self): self.assertTrue(np.allclose(results, r_results, atol=1.e-3)) + def test_posthoc_dunnett(self): + r_results = [8.125844e-11, 2.427434e-01] + results = sp.posthoc_dunnett(self.df.sort_index(), val_col='pulse', group_col='kind', + control='rest', to_matrix=False) + + # scipy use randomized Quasi-Monte Carlo integration of the multivariate-t distribution + # to compute the p-values. The result may vary slightly from run to run. + # we run the test 1000 times (maximum absolute tolerance = 1.e-4 for example data) + is_close = [] + for i in range(1000): + is_close.append(np.allclose(results, r_results, atol=1.e-4)) + is_close = all(is_close) + self.assertTrue(is_close) + if __name__ == '__main__': unittest.main() +