From c7209d4bf5f8dbd325d0d77203027e333c5561f1 Mon Sep 17 00:00:00 2001 From: Amanda Potts Date: Mon, 24 Jun 2024 12:32:24 -0400 Subject: [PATCH] Closes #3354 import scipy.stats functions to arkouda.scipy.stats --- PROTO_tests/tests/dataframe_test.py | 2 +- PROTO_tests/tests/groupby_test.py | 2 +- PROTO_tests/tests/random_test.py | 2 +- .../{scipy_test.py => scipy_stats_test.py} | 4 +- arkouda/__init__.py | 4 +- arkouda/scipy/__init__.py | 4 +- arkouda/scipy/stats/__init__.py | 291 +++++++++++++++++- arkouda/scipy/{ => stats}/_stats_py.py | 8 +- pydoc/preprocess/generate_import_stubs.py | 45 ++- pytest.ini | 2 +- pytest_PROTO.ini | 2 +- scripts/exclude.py | 37 +++ tests/dataframe_test.py | 2 +- tests/groupby_test.py | 2 +- tests/random_test.py | 2 +- .../{scipy_test.py => scipy_stats_test.py} | 4 +- 16 files changed, 383 insertions(+), 30 deletions(-) rename PROTO_tests/tests/scipy/{scipy_test.py => scipy_stats_test.py} (93%) rename arkouda/scipy/{ => stats}/_stats_py.py (97%) create mode 100644 scripts/exclude.py rename tests/scipy/{scipy_test.py => scipy_stats_test.py} (94%) diff --git a/PROTO_tests/tests/dataframe_test.py b/PROTO_tests/tests/dataframe_test.py index 2ee38c060d..28f2e9ff59 100644 --- a/PROTO_tests/tests/dataframe_test.py +++ b/PROTO_tests/tests/dataframe_test.py @@ -10,7 +10,7 @@ import arkouda as ak from arkouda import io_util -from arkouda.scipy import chisquare as akchisquare +from arkouda.scipy.stats import chisquare as akchisquare def alternating_1_0(n): diff --git a/PROTO_tests/tests/groupby_test.py b/PROTO_tests/tests/groupby_test.py index 8602233be0..fb71192b9c 100644 --- a/PROTO_tests/tests/groupby_test.py +++ b/PROTO_tests/tests/groupby_test.py @@ -4,7 +4,7 @@ import arkouda as ak from arkouda.groupbyclass import GroupByReductionType -from arkouda.scipy import chisquare as akchisquare +from arkouda.scipy.stats import chisquare as akchisquare def to_tuple_dict(labels, values): diff --git a/PROTO_tests/tests/random_test.py 
b/PROTO_tests/tests/random_test.py index f447849930..530bcdbb32 100644 --- a/PROTO_tests/tests/random_test.py +++ b/PROTO_tests/tests/random_test.py @@ -6,7 +6,7 @@ from scipy import stats as sp_stats import arkouda as ak -from arkouda.scipy import chisquare as akchisquare +from arkouda.scipy.stats import chisquare as akchisquare class TestRandom: diff --git a/PROTO_tests/tests/scipy/scipy_test.py b/PROTO_tests/tests/scipy/scipy_stats_test.py similarity index 93% rename from PROTO_tests/tests/scipy/scipy_test.py rename to PROTO_tests/tests/scipy/scipy_stats_test.py index 863d8cb1a0..94ea2dbe0a 100644 --- a/PROTO_tests/tests/scipy/scipy_test.py +++ b/PROTO_tests/tests/scipy/scipy_stats_test.py @@ -4,8 +4,8 @@ from scipy.stats import power_divergence as scipy_power_divergence import arkouda as ak -from arkouda.scipy import chisquare as ak_chisquare -from arkouda.scipy import power_divergence as ak_power_divergence +from arkouda.scipy.stats import chisquare as ak_chisquare +from arkouda.scipy.stats import power_divergence as ak_power_divergence DDOF = [0, 1, 2, 3, 4, 5] PAIRS = [ diff --git a/arkouda/__init__.py b/arkouda/__init__.py index d6b8b50cf4..a7089b7a68 100644 --- a/arkouda/__init__.py +++ b/arkouda/__init__.py @@ -40,6 +40,4 @@ is_registered, broadcast_dims, ) -from arkouda.scipy.special import * -from arkouda.scipy import * -from arkouda.random import * +from arkouda.random import * \ No newline at end of file diff --git a/arkouda/scipy/__init__.py b/arkouda/scipy/__init__.py index d34f5e0f30..0dc31a2ab4 100644 --- a/arkouda/scipy/__init__.py +++ b/arkouda/scipy/__init__.py @@ -1,3 +1 @@ -from ._stats_py import Power_divergenceResult, chisquare, power_divergence - -__all__ = ["power_divergence", "chisquare", "Power_divergenceResult"] +__all__ = ["special", "stats"] diff --git a/arkouda/scipy/stats/__init__.py b/arkouda/scipy/stats/__init__.py index 046cbc676a..124079b36d 100644 --- a/arkouda/scipy/stats/__init__.py +++ b/arkouda/scipy/stats/__init__.py @@ 
-1,4 +1,290 @@ # mypy: ignore-errors -from scipy.stats import chi2 # type: ignore[import-untyped] +from scipy.stats import ( # noqa + ConstantInputWarning, + DegenerateDataWarning, + FitError, + NearConstantInputWarning, + alpha, + anglit, + arcsine, + argus, + bernoulli, + beta, + betabinom, + biasedurn, + binom, + binomtest, + boltzmann, + bradford, + burr, + burr12, + cauchy, + chi, + chi2, + contingency, + cosine, + crystalball, + dgamma, + distributions, + dlaplace, + dweibull, + erlang, + expon, + exponnorm, + exponpow, + exponweib, + f, + fatiguelife, + fisk, + foldcauchy, + foldnorm, + gamma, + gausshyper, + genexpon, + genextreme, + gengamma, + genhalflogistic, + genhyperbolic, + geninvgauss, + genlogistic, + gennorm, + genpareto, + geom, + gibrat, + gompertz, + gumbel_l, + gumbel_r, + halfcauchy, + halfgennorm, + halflogistic, + halfnorm, + hypergeom, + hypsecant, + invgamma, + invgauss, + invweibull, + johnsonsb, + johnsonsu, + kappa3, + kappa4, + kde, + ksone, + kstwo, + kstwobign, + laplace, + laplace_asymmetric, + levy, + levy_l, + levy_stable, + loggamma, + logistic, + loglaplace, + lognorm, + logser, + loguniform, + lomax, + maxwell, + mielke, + morestats, + moyal, + mstats_basic, + mstats_extras, + mvn, + nakagami, + nbinom, + ncf, + nchypergeom_fisher, + nchypergeom_wallenius, + nct, + ncx2, + nhypergeom, + norm, + norminvgauss, + pareto, + pearson3, + planck, + poisson, + poisson_means_test, + powerlaw, + powerlognorm, + powernorm, + randint, + rayleigh, + rdist, + recipinvgauss, + reciprocal, + rice, + sampling, + semicircular, + skellam, + skewcauchy, + skewnorm, + stats, + studentized_range, + t, + trapezoid, + trapz, + triang, + truncexpon, + truncnorm, + truncpareto, + truncweibull_min, + tukeylambda, + uniform, + vonmises, + vonmises_line, + wald, + weibull_max, + weibull_min, + wrapcauchy, + yulesimon, + zipf, + zipfian, +) -__all__ = ["chi2"] +from ._stats_py import Power_divergenceResult, chisquare, power_divergence # noqa +
+scipy_imports = [ + "ConstantInputWarning", + "DegenerateDataWarning", + "FitError", + "NearConstantInputWarning", + "alpha", + "anglit", + "arcsine", + "argus", + "bernoulli", + "beta", + "betabinom", + "biasedurn", + "binom", + "binomtest", + "boltzmann", + "bradford", + "burr", + "burr12", + "cauchy", + "chi", + "chi2", + "contingency", + "cosine", + "crystalball", + "dgamma", + "distributions", + "dlaplace", + "dweibull", + "erlang", + "expon", + "exponnorm", + "exponpow", + "exponweib", + "f", + "fatiguelife", + "fisk", + "foldcauchy", + "foldnorm", + "gamma", + "gausshyper", + "genexpon", + "genextreme", + "gengamma", + "genhalflogistic", + "genhyperbolic", + "geninvgauss", + "genlogistic", + "gennorm", + "genpareto", + "geom", + "gibrat", + "gompertz", + "gumbel_l", + "gumbel_r", + "halfcauchy", + "halfgennorm", + "halflogistic", + "halfnorm", + "hypergeom", + "hypsecant", + "invgamma", + "invgauss", + "invweibull", + "johnsonsb", + "johnsonsu", + "kappa3", + "kappa4", + "kde", + "ksone", + "kstwo", + "kstwobign", + "laplace", + "laplace_asymmetric", + "levy", + "levy_l", + "levy_stable", + "loggamma", + "logistic", + "loglaplace", + "lognorm", + "logser", + "loguniform", + "lomax", + "maxwell", + "mielke", + "morestats", + "moyal", + "mstats_basic", + "mstats_extras", + "mvn", + "nakagami", + "nbinom", + "ncf", + "nchypergeom_fisher", + "nchypergeom_wallenius", + "nct", + "ncx2", + "nhypergeom", + "norm", + "norminvgauss", + "pareto", + "pearson3", + "planck", + "poisson", + "poisson_means_test", + "powerlaw", + "powerlognorm", + "powernorm", + "randint", + "rayleigh", + "rdist", + "recipinvgauss", + "reciprocal", + "rice", + "sampling", + "semicircular", + "skellam", + "skewcauchy", + "skewnorm", + "stats", + "studentized_range", + "t", + "trapezoid", + "trapz", + "triang", + "truncexpon", + "truncnorm", + "truncpareto", + "truncweibull_min", + "tukeylambda", + "uniform", + "vonmises", + "vonmises_line", + "wald", + "weibull_max", +
"weibull_min", + "wrapcauchy", + "yulesimon", + "zipf", + "zipfian", +] + +__all__ = scipy_imports + ["Power_divergenceResult", "power_divergence", "chisquare"] diff --git a/arkouda/scipy/_stats_py.py b/arkouda/scipy/stats/_stats_py.py similarity index 97% rename from arkouda/scipy/_stats_py.py rename to arkouda/scipy/stats/_stats_py.py index 7a45576666..aa300d8e8d 100644 --- a/arkouda/scipy/_stats_py.py +++ b/arkouda/scipy/stats/_stats_py.py @@ -2,13 +2,13 @@ import numpy as np from numpy import asarray -from scipy.stats import chi2 # type: ignore +from scipy.stats import chi2 # type:ignore import arkouda as ak -from arkouda.scipy.special import xlogy from arkouda.dtypes import float64 as akfloat64 +from arkouda.scipy.special import xlogy -__all__ = ["power_divergence", "chisquare", "Power_divergenceResult"] +__all__ = ["Power_divergenceResult", "power_divergence", "chisquare"] class Power_divergenceResult(namedtuple("Power_divergenceResult", ("statistic", "pvalue"))): @@ -74,7 +74,7 @@ def power_divergence(f_obs, f_exp=None, ddof=0, lambda_=None): >>> import arkouda as ak >>> ak.connect() - >>> from arkouda.stats import power_divergence + >>> from arkouda.scipy.stats import power_divergence >>> x = ak.array([10, 20, 30, 10]) >>> y = ak.array([10, 30, 20, 10]) >>> power_divergence(x, y, lambda_="pearson") diff --git a/pydoc/preprocess/generate_import_stubs.py b/pydoc/preprocess/generate_import_stubs.py index d1ad75f3d8..5714d12b7a 100644 --- a/pydoc/preprocess/generate_import_stubs.py +++ b/pydoc/preprocess/generate_import_stubs.py @@ -3,14 +3,45 @@ import re +def clean_signature(signature: inspect.Signature) -> str: + + function_at_pattern = "\w+=" + object_at_pattern = "\w+=" + + signature_string = str(signature) + + match = re.findall(function_at_pattern, signature_string) + if len(match) > 0: + for key in signature.parameters.keys(): + param = signature.parameters.get(key) + default = param.default + if "function" in str(default): + mod = default.__module__ + 
name = default.__name__ + replacement = param.name + "=" + mod + "." + name + signature_string = re.sub(function_at_pattern, replacement, signature_string, count=1) + + match = re.findall(object_at_pattern, signature_string) + if len(match) > 0: + for key in signature.parameters.keys(): + param = signature.parameters.get(key) + default = param.default + if "object" in str(default): + replacement = param.name + "=object" + signature_string = re.sub(object_at_pattern, replacement, signature_string, count=1) + + signature_string = re.sub(object_at_pattern, "", signature_string, count=1) + return signature_string + + def insert_spaces_after_newlines(input_string, spaces): if input_string is not None: pattern = r"^\n(\s+)" - starting_indents = re.findall(pattern, input_string) + starting_indents = [item for item in re.findall(pattern, input_string) if len(item) > 0] if len(starting_indents) > 0: old_indent_pattern = "^" + starting_indents[0] else: - return input_string + old_indent_pattern = "^" lines = input_string.split("\n") result = [] @@ -47,9 +78,9 @@ def get_parent_class_str(obj): def write_formatted_docstring(f, doc_string, spaces): doc_string = insert_spaces_after_newlines(doc_string, spaces) if doc_string is not None and len(doc_string) > 0: - f.write(spaces + "r'''\n") + f.write(spaces + 'r"""\n') f.write(f"{doc_string}\n") - f.write(spaces + "'''") + f.write(spaces + '"""') f.write("\n" + spaces + "...") else: f.write("\n" + spaces + "...") @@ -81,7 +112,9 @@ def write_stub(module, filename, all_only=False, allow_arkouda=False): elif inspect.isfunction(obj): if not name.startswith("__"): try: - f.write(f"def {name}{inspect.signature(obj)}:\n") + signature = clean_signature(inspect.signature(obj)) + + f.write(f"def {name}{signature}:\n") except: f.write(f"def {name}(self, *args, **kwargs):\n") @@ -125,7 +158,7 @@ def write_stub(module, filename, all_only=False, allow_arkouda=False): f.write(f" def {func_name}{signature}:\n") else: try: - signature = 
str(inspect.signature(func)) + signature = clean_signature(inspect.signature(func)) if "self" not in signature: signature = signature.replace("(", "(self, ") f.write(f" def {func_name}{signature}:\n") diff --git a/pytest.ini b/pytest.ini index f8369bfe35..e805b1fa3e 100644 --- a/pytest.ini +++ b/pytest.ini @@ -2,7 +2,7 @@ filterwarnings = ignore:Version mismatch between client .* testpaths = - tests/scipy/scipy_test.py + tests/scipy/scipy_stats_test.py tests/scipy/special_test.py tests/numpy/numpy_test.py tests/alignment_tests.py diff --git a/pytest_PROTO.ini b/pytest_PROTO.ini index e86793eada..46553186fc 100644 --- a/pytest_PROTO.ini +++ b/pytest_PROTO.ini @@ -29,7 +29,7 @@ testpaths = PROTO_tests/tests/pdarray_creation_test.py PROTO_tests/tests/random_test.py PROTO_tests/tests/regex_test.py - PROTO_tests/tests/scipy/scipy_test.py + PROTO_tests/tests/scipy/scipy_stats_test.py PROTO_tests/tests/security_test.py PROTO_tests/tests/segarray_test.py PROTO_tests/tests/series_test.py diff --git a/scripts/exclude.py b/scripts/exclude.py new file mode 100644 index 0000000000..2e9d8ef675 --- /dev/null +++ b/scripts/exclude.py @@ -0,0 +1,37 @@ +# This script finds all the classes and functions whose docstring contains "array". These functions are most likely not natively compatible with arkouda and should probably be excluded from any automatic imports.
+ +import inspect + +def exclude(name, obj)->bool: + if hasattr(obj, "__doc__") and not name.startswith("__") and obj.__doc__ is not None: + if "array" in obj.__doc__: + return True + return False + + +def main(): + + import scipy.stats as scipyStats + + exclude_set = set() + for name, obj in inspect.getmembers(scipyStats): + if exclude(name, obj): + exclude_set.add(name) + if inspect.isclass(obj): + for func_name, func in inspect.getmembers(obj): + if exclude(func_name, func): # a member's docstring mentions "array": exclude the owning class + exclude_set.add(name) + + exclude_list = list(exclude_set) + keep_list = list(set(dir(scipyStats)).difference(exclude_list)) + keep_list = [item for item in keep_list if not item.startswith("_")] + keep_list = sorted(keep_list) + + print("EXCLUDE:") + print(exclude_list) + print("KEEP:") + print(keep_list) + + +if __name__ == "__main__": + main() diff --git a/tests/dataframe_test.py b/tests/dataframe_test.py index 8366ca4b4d..113552e804 100644 --- a/tests/dataframe_test.py +++ b/tests/dataframe_test.py @@ -13,7 +13,7 @@ from arkouda import io_util from arkouda.index import Index -from arkouda.scipy import chisquare as akchisquare +from arkouda.scipy.stats import chisquare as akchisquare def build_ak_df(): diff --git a/tests/groupby_test.py b/tests/groupby_test.py index 3144d62b6a..e868b3bc9e 100644 --- a/tests/groupby_test.py +++ b/tests/groupby_test.py @@ -5,7 +5,7 @@ from arkouda.dtypes import float64, int64 from arkouda.groupbyclass import GroupByReductionType -from arkouda.scipy import chisquare as akchisquare +from arkouda.scipy.stats import chisquare as akchisquare SIZE = 100 GROUPS = 8 diff --git a/tests/random_test.py b/tests/random_test.py index c930cd6d6d..1b398d76f9 100644 --- a/tests/random_test.py +++ b/tests/random_test.py @@ -6,7 +6,7 @@ from context import arkouda as ak from scipy import stats as sp_stats -from arkouda.scipy import chisquare as akchisquare +from arkouda.scipy.stats import chisquare as akchisquare class RandomTest(ArkoudaTest): diff --git
a/tests/scipy/scipy_test.py b/tests/scipy/scipy_stats_test.py similarity index 94% rename from tests/scipy/scipy_test.py rename to tests/scipy/scipy_stats_test.py index 8dcb7f65a5..d9fc3f5ec0 100644 --- a/tests/scipy/scipy_test.py +++ b/tests/scipy/scipy_stats_test.py @@ -23,7 +23,7 @@ def create_stat_test_pairs(self): def test_power_divergence(self): from scipy.stats import power_divergence as scipy_power_divergence - from arkouda.scipy import power_divergence as ak_power_divergence + from arkouda.scipy.stats import power_divergence as ak_power_divergence pairs = self.create_stat_test_pairs() @@ -57,7 +57,7 @@ def test_power_divergence(self): def test_chisquare(self): from scipy.stats import chisquare as scipy_chisquare - from arkouda.scipy import chisquare as ak_chisquare + from arkouda.scipy.stats import chisquare as ak_chisquare pairs = self.create_stat_test_pairs()