From 94df8e90eca93ab63b11dded2e8b3faa98042ba9 Mon Sep 17 00:00:00 2001 From: Piotrek Date: Mon, 14 Sep 2020 22:22:25 +0200 Subject: [PATCH] add catch for exception in permutation importance (#185) --- examples/scripts/multi_class_classifier.py | 2 +- supervised/utils/importance.py | 63 +++++++++------------- 2 files changed, 26 insertions(+), 39 deletions(-) diff --git a/examples/scripts/multi_class_classifier.py b/examples/scripts/multi_class_classifier.py index a07615da..d58788e6 100644 --- a/examples/scripts/multi_class_classifier.py +++ b/examples/scripts/multi_class_classifier.py @@ -22,7 +22,7 @@ X = df[["feature_1", "feature_2", "feature_3", "feature_4"]] y = df["class"] -automl = AutoML() +automl = AutoML() automl.fit(X, y) diff --git a/supervised/utils/importance.py b/supervised/utils/importance.py index 9f6e2840..96e2ba43 100644 --- a/supervised/utils/importance.py +++ b/supervised/utils/importance.py @@ -48,43 +48,30 @@ def compute_and_plot( else: scoring = "neg_mean_squared_error" - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - importance = permutation_importance( - model, - X_validation, - y_validation, - scoring=scoring, - n_jobs=-1, # all cores - random_state=12, - n_repeats=5, # default - ) - - sorted_idx = importance["importances_mean"].argsort() + try: + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + importance = permutation_importance( + model, + X_validation, + y_validation, + scoring=scoring, + n_jobs=-1, # all cores + random_state=12, + n_repeats=5, # default + ) - # save detailed importance - df_imp = pd.DataFrame( - { - "feature": X_validation.columns[sorted_idx], - "mean_importance": importance["importances_mean"][sorted_idx], - } - ) - df_imp.to_csv( - os.path.join(model_file_path, f"{learner_name}_importance.csv"), index=False - ) + sorted_idx = importance["importances_mean"].argsort() - """ - # Do not plot. We will plot aggregate of all importances. - # limit number of column in the plot - if len(sorted_idx) > 50: - sorted_idx = sorted_idx[-50:] - plt.figure(figsize=(10, 7)) - plt.barh( - X_validation.columns[sorted_idx], importance["importances_mean"][sorted_idx] - ) - plt.xlabel("Mean of feature importance") - plt.tight_layout(pad=2.0) - plot_path = os.path.join(model_file_path, f"{learner_name}_importance.png") - plt.savefig(plot_path) - plt.close("all") - """ + # save detailed importance + df_imp = pd.DataFrame( + { + "feature": X_validation.columns[sorted_idx], + "mean_importance": importance["importances_mean"][sorted_idx], + } + ) + df_imp.to_csv( + os.path.join(model_file_path, f"{learner_name}_importance.csv"), index=False + ) + except Exception as e: + print("Problem during computing permutation importance. Skipping ...") \ No newline at end of file