Commit 99b4f6b

refactor: clean up import statements and improve code formatting in tests
Lucas-CE committed Jan 4, 2025
1 parent 971a6e2 commit 99b4f6b
Showing 3 changed files with 61 additions and 57 deletions.
104 changes: 52 additions & 52 deletions tests.ipynb
@@ -15,18 +15,17 @@
 },
 "outputs": [],
 "source": [
-"import matplotlib\n",
-"import shap\n",
+"import matplotlib.pyplot as plt\n",
 "import numpy as np\n",
 "import plotly\n",
-"import matplotlib.pyplot as plt\n",
+"import shap\n",
+"from datasets import concatenate_datasets\n",
 "from sklearn.inspection import (\n",
+"    PartialDependenceDisplay,\n",
 "    partial_dependence,\n",
-"    permutation_importance, \n",
-"    PartialDependenceDisplay\n",
+"    permutation_importance,\n",
 ")\n",
-"from sklearn.metrics import accuracy_score, balanced_accuracy_score, make_scorer\n",
-"from datasets import concatenate_datasets"
+"from sklearn.metrics import accuracy_score, make_scorer"
 ]
 },
 {
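The reordered cell follows the usual import-sorting conventions (plain imports first, then from-imports, alphabetized, with trailing commas inside parenthesized lists); this matches what tools such as isort or Ruff produce, though the commit does not say which tool was used. For reference, the cell after this commit reads as follows, reassembled from the diff above:

import matplotlib.pyplot as plt
import numpy as np
import plotly
import shap
from datasets import concatenate_datasets
from sklearn.inspection import (
    PartialDependenceDisplay,
    partial_dependence,
    permutation_importance,
)
from sklearn.metrics import accuracy_score, make_scorer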
@@ -48,13 +47,13 @@
 "    load_dataset,\n",
 "    select_columns,\n",
 ")\n",
-"from DashAI.back.tasks import TabularClassificationTask\n",
-"from DashAI.back.models import KNeighborsClassifier, SVC\n",
 "from DashAI.back.explainability import (\n",
-"    KernelShap, \n",
-"    PermutationFeatureImportance, \n",
-"    PartialDependence\n",
-")"
+"    KernelShap,\n",
+"    PartialDependence,\n",
+"    PermutationFeatureImportance,\n",
+")\n",
+"from DashAI.back.models import SVC\n",
+"from DashAI.back.tasks import TabularClassificationTask"
 ]
 },
 {
@@ -98,25 +97,23 @@
 },
 "outputs": [],
 "source": [
-"task= TabularClassificationTask()\n",
+"task = TabularClassificationTask()\n",
 "\n",
 "cc_path = \"/home/naabarca/.DashAI/datasets/breast/dataset\"\n",
 "cc_dataset = load_dataset(cc_path)\n",
-"cc_inputs_columns = list(cc_dataset['train'].features)[:-1]\n",
-"cc_outputs_columns = ['diagnosis']\n",
+"cc_inputs_columns = list(cc_dataset[\"train\"].features)[:-1]\n",
+"cc_outputs_columns = [\"diagnosis\"]\n",
 "\n",
 "cc_dataset = task.prepare_for_task(\n",
-"    datasetdict=cc_dataset,\n",
-"    outputs_columns=cc_outputs_columns\n",
+"    datasetdict=cc_dataset, outputs_columns=cc_outputs_columns\n",
 ")\n",
 "cc_dataset = select_columns(cc_dataset, cc_inputs_columns, cc_outputs_columns)\n",
 "\n",
 "cc_path_i = \"/home/naabarca/.DashAI/datasets/breast_instances/dataset\"\n",
 "cc_dataset_i = load_dataset(cc_path_i)\n",
 "\n",
 "cc_dataset_i = task.prepare_for_task(\n",
-"    datasetdict=cc_dataset_i,\n",
-"    outputs_columns=cc_outputs_columns\n",
+"    datasetdict=cc_dataset_i, outputs_columns=cc_outputs_columns\n",
 ")\n",
 "cc_dataset_i = select_columns(cc_dataset_i, cc_inputs_columns, cc_outputs_columns)"
 ]
@@ -225,23 +222,23 @@
 "source": [
 "iris_path = \"/home/naabarca/.DashAI/datasets/iris/dataset\"\n",
 "iris_dataset = load_dataset(iris_path)\n",
-"iris_inputs_columns = ['SepalLengthCm', 'SepalWidthCm', 'PetalWidthCm']\n",
-"iris_outputs_columns = ['Species']\n",
+"iris_inputs_columns = [\"SepalLengthCm\", \"SepalWidthCm\", \"PetalWidthCm\"]\n",
+"iris_outputs_columns = [\"Species\"]\n",
 "\n",
 "iris_dataset = task.prepare_for_task(\n",
-"    datasetdict=iris_dataset,\n",
-"    outputs_columns=iris_outputs_columns\n",
+"    datasetdict=iris_dataset, outputs_columns=iris_outputs_columns\n",
 ")\n",
 "iris_dataset = select_columns(iris_dataset, iris_inputs_columns, iris_outputs_columns)\n",
 "\n",
 "iris_path_i = \"/home/naabarca/.DashAI/datasets/iris_instances/dataset\"\n",
 "iris_dataset_i = load_dataset(iris_path_i)\n",
 "\n",
 "iris_dataset_i = task.prepare_for_task(\n",
-"    datasetdict=iris_dataset_i,\n",
-"    outputs_columns=iris_outputs_columns\n",
+"    datasetdict=iris_dataset_i, outputs_columns=iris_outputs_columns\n",
 ")\n",
-"iris_dataset_i= select_columns(iris_dataset_i, iris_inputs_columns, iris_outputs_columns)"
+"iris_dataset_i = select_columns(\n",
+"    iris_dataset_i, iris_inputs_columns, iris_outputs_columns\n",
+")"
 ]
 },
 {
@@ -317,7 +314,7 @@
 },
 "outputs": [],
 "source": [
-"#model_class = KNeighborsClassifier()\n",
+"# model_class = KNeighborsClassifier()\n",
 "model_class = SVC()\n",
 "model_path = f\"/home/naabarca/.DashAI/runs/{run_id}\"\n",
 "model = model_class.load(model_path)"
@@ -2982,10 +2979,10 @@
 ],
 "source": [
 "parameters = {\n",
-"    \"grid_resolution\": 15,\n",
-"    \"lower_percentile\": 0.05,\n",
-"    \"upper_percentile\": 0.95,\n",
-"    }\n",
+"    \"grid_resolution\": 15,\n",
+"    \"lower_percentile\": 0.05,\n",
+"    \"upper_percentile\": 0.95,\n",
+"}\n",
 "explainer = PartialDependence(model, **parameters)\n",
 "explanation = explainer.explain(cc_dataset)\n",
 "plot = explainer.plot(explanation)\n",
@@ -3047,7 +3044,7 @@
 }
 ],
 "source": [
-"explanation['area_worst']"
+"explanation[\"area_worst\"]"
 ]
 },
 {
@@ -3065,7 +3062,7 @@
 "        feature_names=cc_inputs_columns,\n",
 "        percentiles=(0.05, 0.95),\n",
 "        grid_resolution=15,\n",
-"        kind=\"average\"\n",
+"        kind=\"average\",\n",
 "    )\n",
 "    print(cc_inputs_columns[idx])\n",
 "    print(pd)"
@@ -3108,13 +3105,13 @@
 ],
 "source": [
 "display = PartialDependenceDisplay.from_estimator(\n",
-"    estimator=model, \n",
-"    X=cc_dataset[0][\"test\"].to_pandas(), \n",
-"    features=[10], \n",
-"    # target=None, \n",
-"    kind='average',\n",
+"    estimator=model,\n",
+"    X=cc_dataset[0][\"test\"].to_pandas(),\n",
+"    features=[10],\n",
+"    # target=None,\n",
+"    kind=\"average\",\n",
 "    grid_resolution=15,\n",
-"    percentiles=(0.05, 0.95)\n",
+"    percentiles=(0.05, 0.95),\n",
 ")"
 ]
 },
@@ -6155,16 +6152,17 @@
 "def patched_metric(y_true, y_pred_probas):\n",
 "    return accuracy_score(y_true, np.argmax(y_pred_probas, axis=1))\n",
 "\n",
+"\n",
 "pi = permutation_importance(\n",
-"    estimator=model, \n",
-"    X=cc_dataset[0][\"train\"].to_pandas(), \n",
-"    y=cc_dataset[1][\"train\"].to_pandas(), \n",
+"    estimator=model,\n",
+"    X=cc_dataset[0][\"train\"].to_pandas(),\n",
+"    y=cc_dataset[1][\"train\"].to_pandas(),\n",
 "    scoring=make_scorer(patched_metric),\n",
 "    n_repeats=20,\n",
 "    random_state=5,\n",
-"    max_samples=20\n",
+"    max_samples=20,\n",
 ")\n",
-"plt.barh(range(30), pi['importances_mean'], color='b', align='center')"
+"plt.barh(range(30), pi[\"importances_mean\"], color=\"b\", align=\"center\")"
 ]
 },
 {
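The patched_metric wrapper above exists because make_scorer feeds the estimator's predict output straight to the metric; when a model's predict returns per-class probabilities rather than labels, the metric has to argmax them first. A minimal self-contained sketch of the same pattern; the ProbaModel class is a hypothetical stand-in for the DashAI model, whose predictions are probability arrays:

import numpy as np
from sklearn.datasets import load_breast_cancer
from sklearn.ensemble import RandomForestClassifier
from sklearn.inspection import permutation_importance
from sklearn.metrics import accuracy_score, make_scorer


class ProbaModel:
    # Stand-in whose predict() returns per-class probabilities,
    # like the model being explained in the notebook.
    def __init__(self, clf):
        self.clf = clf

    def fit(self, X, y):
        self.clf.fit(X, y)
        return self

    def predict(self, X):
        return self.clf.predict_proba(X)


def patched_metric(y_true, y_pred_probas):
    # Collapse probabilities to the most likely class before scoring.
    return accuracy_score(y_true, np.argmax(y_pred_probas, axis=1))


X, y = load_breast_cancer(return_X_y=True)
model = ProbaModel(RandomForestClassifier(random_state=5)).fit(X, y)
pi = permutation_importance(
    model,
    X,
    y,
    scoring=make_scorer(patched_metric),
    n_repeats=20,
    random_state=5,
    max_samples=20,
)
print(pi["importances_mean"].shape)  # one mean importance per feature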
@@ -7159,8 +7157,8 @@
 ],
 "source": [
 "parameters = {\n",
-"    \"link\": \"identity\",\n",
-"    }\n",
+"    \"link\": \"identity\",\n",
+"}\n",
 "fit_parameters = {\n",
 "    \"sample_background_data\": True,\n",
 "    \"n_background_samples\": 50,\n",
@@ -7208,7 +7206,7 @@
 "    model=model.predict_proba,\n",
 "    data=cc_dataset[0][\"train\"].to_pandas(),\n",
 "    feature_names=cc_inputs_columns,\n",
-"    link='identity'\n",
+"    link=\"identity\",\n",
 ")\n",
 "\n",
 "shap_values = kernel_shap.shap_values(X)"
@@ -7267,10 +7265,12 @@
 ],
 "source": [
 "clase = 1\n",
-"explanation_shap = shap.Explanation(values=shap_values[clase][instance], \n",
-"    base_values=kernel_shap.expected_value[clase], \n",
-"    data=X.values,\n",
-"    feature_names=cc_inputs_columns)\n",
+"explanation_shap = shap.Explanation(\n",
+"    values=shap_values[clase][instance],\n",
+"    base_values=kernel_shap.expected_value[clase],\n",
+"    data=X.values,\n",
+"    feature_names=cc_inputs_columns,\n",
+")\n",
 "shap.plots.waterfall(explanation_shap)"
 ]
 },
2 changes: 1 addition & 1 deletion tests/back/job_queue/test_simple_job_queue.py
@@ -81,7 +81,7 @@ def test_get_jobs(job_queue: BaseJobQueue):
     assert jobs[0].id == job_2_id
 
 
-@pytest.mark.asyncio()
+@pytest.mark.asyncio
 async def test_async_get_job(job_queue: BaseJobQueue):
     job_1 = DummyJob()
     job_1_id = job_queue.put(job_1)
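pytest accepts a marker with or without the empty call parentheses; the bare @pytest.mark.asyncio form is the one pytest-asyncio's documentation uses, and dropping the parentheses is consistent with the cleanup in the rest of this commit. A minimal sketch of a test in that style, assuming pytest-asyncio is installed:

import asyncio

import pytest


@pytest.mark.asyncio
async def test_sleep_yields_control():
    # pytest-asyncio provides the event loop that runs this coroutine.
    assert await asyncio.sleep(0) is None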
12 changes: 8 additions & 4 deletions tests/back/plugins/test_plugin_utils.py
@@ -206,8 +206,9 @@ def test_error_execute_pip_command():
     subprocess_mock = Mock()
     subprocess_mock.returncode = 1
     subprocess_mock.stderr = "ERROR: ...\nERROR: ..."
-    with patch("subprocess.run", return_value=subprocess_mock), pytest.raises(
-        RuntimeError, match="ERROR: ...\nERROR: ..."
+    with (
+        patch("subprocess.run", return_value=subprocess_mock),
+        pytest.raises(RuntimeError, match="ERROR: ...\nERROR: ..."),
     ):
         execute_pip_command("dashai-tabular-classification-package", "install")
 
@@ -234,8 +235,11 @@ def test_uninstall_plugin():
     execute_pip_command_mock = Mock()
     execute_pip_command_mock.return_value = 0
 
-    with patch("DashAI.back.plugins.utils.entry_points", entry_points_mock), patch(
-        "DashAI.back.plugins.utils.execute_pip_command", execute_pip_command_mock
+    with (
+        patch("DashAI.back.plugins.utils.entry_points", entry_points_mock),
+        patch(
+            "DashAI.back.plugins.utils.execute_pip_command", execute_pip_command_mock
+        ),
    ):
         uninsalled_plugins = uninstall_plugin("Plugin1")
 
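Both rewritten tests use parenthesized context managers, which let a multi-manager with statement break across lines without backslash continuations; the syntax is officially supported from Python 3.10, so earlier interpreters need the old nested or backslash style. A minimal sketch of the construct with stand-in managers:

from contextlib import nullcontext

# Python 3.10+: parentheses allow one manager per line, with a trailing comma.
with (
    nullcontext("subprocess patch") as first,
    nullcontext("raises check") as second,
):
    assert first == "subprocess patch"
    assert second == "raises check"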
