From 99b4f6bc3877876c430137790ffb02a5ab30dec7 Mon Sep 17 00:00:00 2001 From: Lucas-CE Date: Sat, 4 Jan 2025 11:34:03 -0300 Subject: [PATCH] refactor: clean up import statements and improve code formatting in tests --- tests.ipynb | 104 +++++++++--------- tests/back/job_queue/test_simple_job_queue.py | 2 +- tests/back/plugins/test_plugin_utils.py | 12 +- 3 files changed, 61 insertions(+), 57 deletions(-) diff --git a/tests.ipynb b/tests.ipynb index 91f019c06..5437d0451 100644 --- a/tests.ipynb +++ b/tests.ipynb @@ -15,18 +15,17 @@ }, "outputs": [], "source": [ - "import matplotlib\n", - "import shap\n", + "import matplotlib.pyplot as plt\n", "import numpy as np\n", "import plotly\n", - "import matplotlib.pyplot as plt\n", + "import shap\n", + "from datasets import concatenate_datasets\n", "from sklearn.inspection import (\n", + " PartialDependenceDisplay,\n", " partial_dependence,\n", - " permutation_importance, \n", - " PartialDependenceDisplay\n", + " permutation_importance,\n", ")\n", - "from sklearn.metrics import accuracy_score, balanced_accuracy_score, make_scorer\n", - "from datasets import concatenate_datasets" + "from sklearn.metrics import accuracy_score, make_scorer" ] }, { @@ -48,13 +47,13 @@ " load_dataset,\n", " select_columns,\n", ")\n", - "from DashAI.back.tasks import TabularClassificationTask\n", - "from DashAI.back.models import KNeighborsClassifier, SVC\n", "from DashAI.back.explainability import (\n", - " KernelShap, \n", - " PermutationFeatureImportance, \n", - " PartialDependence\n", - ")" + " KernelShap,\n", + " PartialDependence,\n", + " PermutationFeatureImportance,\n", + ")\n", + "from DashAI.back.models import SVC\n", + "from DashAI.back.tasks import TabularClassificationTask" ] }, { @@ -98,16 +97,15 @@ }, "outputs": [], "source": [ - "task= TabularClassificationTask()\n", + "task = TabularClassificationTask()\n", "\n", "cc_path = \"/home/naabarca/.DashAI/datasets/breast/dataset\"\n", "cc_dataset = load_dataset(cc_path)\n", - "cc_inputs_columns = list(cc_dataset['train'].features)[:-1]\n", - "cc_outputs_columns = ['diagnosis']\n", + "cc_inputs_columns = list(cc_dataset[\"train\"].features)[:-1]\n", + "cc_outputs_columns = [\"diagnosis\"]\n", "\n", "cc_dataset = task.prepare_for_task(\n", - " datasetdict=cc_dataset,\n", - " outputs_columns=cc_outputs_columns\n", + " datasetdict=cc_dataset, outputs_columns=cc_outputs_columns\n", ")\n", "cc_dataset = select_columns(cc_dataset, cc_inputs_columns, cc_outputs_columns)\n", "\n", @@ -115,8 +113,7 @@ "cc_dataset_i = load_dataset(cc_path_i)\n", "\n", "cc_dataset_i = task.prepare_for_task(\n", - " datasetdict=cc_dataset_i,\n", - " outputs_columns=cc_outputs_columns\n", + " datasetdict=cc_dataset_i, outputs_columns=cc_outputs_columns\n", ")\n", "cc_dataset_i = select_columns(cc_dataset_i, cc_inputs_columns, cc_outputs_columns)" ] @@ -225,12 +222,11 @@ "source": [ "iris_path = \"/home/naabarca/.DashAI/datasets/iris/dataset\"\n", "iris_dataset = load_dataset(iris_path)\n", - "iris_inputs_columns = ['SepalLengthCm', 'SepalWidthCm', 'PetalWidthCm']\n", - "iris_outputs_columns = ['Species']\n", + "iris_inputs_columns = [\"SepalLengthCm\", \"SepalWidthCm\", \"PetalWidthCm\"]\n", + "iris_outputs_columns = [\"Species\"]\n", "\n", "iris_dataset = task.prepare_for_task(\n", - " datasetdict=iris_dataset,\n", - " outputs_columns=iris_outputs_columns\n", + " datasetdict=iris_dataset, outputs_columns=iris_outputs_columns\n", ")\n", "iris_dataset = select_columns(iris_dataset, iris_inputs_columns, iris_outputs_columns)\n", "\n", @@ -238,10 
+234,11 @@ "iris_dataset_i = load_dataset(iris_path_i)\n", "\n", "iris_dataset_i = task.prepare_for_task(\n", - " datasetdict=iris_dataset_i,\n", - " outputs_columns=iris_outputs_columns\n", + " datasetdict=iris_dataset_i, outputs_columns=iris_outputs_columns\n", ")\n", - "iris_dataset_i= select_columns(iris_dataset_i, iris_inputs_columns, iris_outputs_columns)" + "iris_dataset_i = select_columns(\n", + " iris_dataset_i, iris_inputs_columns, iris_outputs_columns\n", + ")" ] }, { @@ -317,7 +314,7 @@ }, "outputs": [], "source": [ - "#model_class = KNeighborsClassifier()\n", + "# model_class = KNeighborsClassifier()\n", "model_class = SVC()\n", "model_path = f\"/home/naabarca/.DashAI/runs/{run_id}\"\n", "model = model_class.load(model_path)" @@ -2982,10 +2979,10 @@ ], "source": [ "parameters = {\n", - " \"grid_resolution\": 15,\n", - " \"lower_percentile\": 0.05,\n", - " \"upper_percentile\": 0.95,\n", - " }\n", + " \"grid_resolution\": 15,\n", + " \"lower_percentile\": 0.05,\n", + " \"upper_percentile\": 0.95,\n", + "}\n", "explainer = PartialDependence(model, **parameters)\n", "explanation = explainer.explain(cc_dataset)\n", "plot = explainer.plot(explanation)\n", @@ -3047,7 +3044,7 @@ } ], "source": [ - "explanation['area_worst']" + "explanation[\"area_worst\"]" ] }, { @@ -3065,7 +3062,7 @@ " feature_names=cc_inputs_columns,\n", " percentiles=(0.05, 0.95),\n", " grid_resolution=15,\n", - " kind=\"average\"\n", + " kind=\"average\",\n", " )\n", " print(cc_inputs_columns[idx])\n", " print(pd)" @@ -3108,13 +3105,13 @@ ], "source": [ "display = PartialDependenceDisplay.from_estimator(\n", - " estimator=model, \n", - " X=cc_dataset[0][\"test\"].to_pandas(), \n", - " features=[10], \n", - " # target=None, \n", - " kind='average',\n", + " estimator=model,\n", + " X=cc_dataset[0][\"test\"].to_pandas(),\n", + " features=[10],\n", + " # target=None,\n", + " kind=\"average\",\n", " grid_resolution=15,\n", - " percentiles=(0.05, 0.95)\n", + " percentiles=(0.05, 0.95),\n", ")" ] }, @@ -6155,16 +6152,17 @@ "def patched_metric(y_true, y_pred_probas):\n", " return accuracy_score(y_true, np.argmax(y_pred_probas, axis=1))\n", "\n", + "\n", "pi = permutation_importance(\n", - " estimator=model, \n", - " X=cc_dataset[0][\"train\"].to_pandas(), \n", - " y=cc_dataset[1][\"train\"].to_pandas(), \n", + " estimator=model,\n", + " X=cc_dataset[0][\"train\"].to_pandas(),\n", + " y=cc_dataset[1][\"train\"].to_pandas(),\n", " scoring=make_scorer(patched_metric),\n", " n_repeats=20,\n", " random_state=5,\n", - " max_samples=20\n", + " max_samples=20,\n", ")\n", - "plt.barh(range(30), pi['importances_mean'], color='b', align='center')" + "plt.barh(range(30), pi[\"importances_mean\"], color=\"b\", align=\"center\")" ] }, { @@ -7159,8 +7157,8 @@ ], "source": [ "parameters = {\n", - " \"link\": \"identity\",\n", - " }\n", + " \"link\": \"identity\",\n", + "}\n", "fit_parameters = {\n", " \"sample_background_data\": True,\n", " \"n_background_samples\": 50,\n", @@ -7208,7 +7206,7 @@ " model=model.predict_proba,\n", " data=cc_dataset[0][\"train\"].to_pandas(),\n", " feature_names=cc_inputs_columns,\n", - " link='identity'\n", + " link=\"identity\",\n", ")\n", "\n", "shap_values = kernel_shap.shap_values(X)" @@ -7267,10 +7265,12 @@ ], "source": [ "clase = 1\n", - "explanation_shap = shap.Explanation(values=shap_values[clase][instance], \n", - " base_values=kernel_shap.expected_value[clase], \n", - " data=X.values,\n", - " feature_names=cc_inputs_columns)\n", + "explanation_shap = shap.Explanation(\n", + " 
values=shap_values[clase][instance],\n", + " base_values=kernel_shap.expected_value[clase],\n", + " data=X.values,\n", + " feature_names=cc_inputs_columns,\n", + ")\n", "shap.plots.waterfall(explanation_shap)" ] }, diff --git a/tests/back/job_queue/test_simple_job_queue.py b/tests/back/job_queue/test_simple_job_queue.py index fe2b35ac2..c965d625f 100644 --- a/tests/back/job_queue/test_simple_job_queue.py +++ b/tests/back/job_queue/test_simple_job_queue.py @@ -81,7 +81,7 @@ def test_get_jobs(job_queue: BaseJobQueue): assert jobs[0].id == job_2_id -@pytest.mark.asyncio() +@pytest.mark.asyncio async def test_async_get_job(job_queue: BaseJobQueue): job_1 = DummyJob() job_1_id = job_queue.put(job_1) diff --git a/tests/back/plugins/test_plugin_utils.py b/tests/back/plugins/test_plugin_utils.py index 8c5f3c110..978cb8ea2 100644 --- a/tests/back/plugins/test_plugin_utils.py +++ b/tests/back/plugins/test_plugin_utils.py @@ -206,8 +206,9 @@ def test_error_execute_pip_command(): subprocess_mock = Mock() subprocess_mock.returncode = 1 subprocess_mock.stderr = "ERROR: ...\nERROR: ..." - with patch("subprocess.run", return_value=subprocess_mock), pytest.raises( - RuntimeError, match="ERROR: ...\nERROR: ..." + with ( + patch("subprocess.run", return_value=subprocess_mock), + pytest.raises(RuntimeError, match="ERROR: ...\nERROR: ..."), ): execute_pip_command("dashai-tabular-classification-package", "install") @@ -234,8 +235,11 @@ def test_uninstall_plugin(): execute_pip_command_mock = Mock() execute_pip_command_mock.return_value = 0 - with patch("DashAI.back.plugins.utils.entry_points", entry_points_mock), patch( - "DashAI.back.plugins.utils.execute_pip_command", execute_pip_command_mock + with ( + patch("DashAI.back.plugins.utils.entry_points", entry_points_mock), + patch( + "DashAI.back.plugins.utils.execute_pip_command", execute_pip_command_mock + ), ): uninsalled_plugins = uninstall_plugin("Plugin1")
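Note on the `with` rewrite in tests/back/plugins/test_plugin_utils.py: the parenthesized multi-context-manager syntax the patch adopts is only officially supported from Python 3.10 onward, so this change implies a 3.10+ floor for the test suite. A minimal, self-contained sketch of the pattern, using stand-in patch targets rather than the DashAI internals touched above (`os.getcwd` and the `DASHAI_ENV` variable are illustrative assumptions, not names from the patch):

    # Parenthesized context managers (Python 3.10+): each manager sits on its
    # own line with a trailing comma, instead of wrapping inside a call's
    # argument list as the removed code did.
    import os
    from unittest.mock import patch

    with (
        patch("os.getcwd", return_value="/fake/cwd"),
        patch.dict(os.environ, {"DASHAI_ENV": "test"}),
    ):
        assert os.getcwd() == "/fake/cwd"
        assert os.environ["DASHAI_ENV"] == "test"

The `@pytest.mark.asyncio()` to `@pytest.mark.asyncio` change in test_simple_job_queue.py is purely stylistic: pytest marks accept both the called and the bare form, and the bare form is the conventional one.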