From 2dd89cd3fddec355489f1ff098e4582862245d09 Mon Sep 17 00:00:00 2001 From: Ian Faust Date: Tue, 15 Oct 2024 15:01:41 +0200 Subject: [PATCH] switch to dict (#2111) --- sklearnex/tests/test_common.py | 27 ++++++++++++++------------- 1 file changed, 14 insertions(+), 13 deletions(-) diff --git a/sklearnex/tests/test_common.py b/sklearnex/tests/test_common.py index b0ec5992e7..f427dfb982 100644 --- a/sklearnex/tests/test_common.py +++ b/sklearnex/tests/test_common.py @@ -306,12 +306,12 @@ def estimator_trace(estimator, method, cache, capsys, monkeypatch): cache.set("key", key) cache.set( "text", - [ - re.findall(regex_func, text), - text, - [i.replace(os.sep, ".") for i in re.findall(regex_mod, text)], - [""] + re.findall(regex_callingline, text), - ], + { + "funcs": re.findall(regex_func, text), + "trace": text, + "modules": [i.replace(os.sep, ".") for i in re.findall(regex_mod, text)], + "callingline": [""] + re.findall(regex_callingline, text), + }, ) return cache.get("text", None) @@ -322,8 +322,8 @@ def call_validate_data(text, estimator, method): called once before offloading to oneDAL in sklearnex""" try: # get last to_table call showing end of oneDAL input portion of code - idx = len(text[0]) - 1 - text[0][::-1].index("to_table") - validfuncs = text[0][:idx] + idx = len(text["funcs"]) - 1 - text["funcs"][::-1].index("to_table") + validfuncs = text["funcs"][:idx] except ValueError: pytest.skip("onedal backend not used in this function") @@ -341,16 +341,17 @@ def n_jobs_check(text, estimator, method): """verify the n_jobs is being set if '_get_backend' or 'to_table' is called""" # remove the _get_backend function from sklearnex from considered _get_backend count = max( - text[0].count("to_table"), + text["funcs"].count("to_table"), len( [ i - for i in range(len(text[0])) - if text[0][i] == "_get_backend" and "sklearnex" not in text[2][i] + for i in range(len(text["funcs"])) + if text["funcs"][i] == "_get_backend" + and "sklearnex" not in text["modules"][i] ] ), ) - n_jobs_count = text[0].count("n_jobs_wrapper") + n_jobs_count = text["funcs"].count("n_jobs_wrapper") assert bool(count) == bool( n_jobs_count @@ -360,7 +361,7 @@ def n_jobs_check(text, estimator, method): def runtime_property_check(text, estimator, method): """use of Python's 'property' should not be used at runtime, only at class instantiation""" assert ( - len(re.findall(r"property\(", text[1])) == 0 + len(re.findall(r"property\(", text["trace"])) == 0 ), f"{estimator}.{method} should only use 'property' at instantiation"