Skip to content

Commit

Permalink
add catch for exception in permutation importance (#185)
Browse files Browse the repository at this point in the history
  • Loading branch information
pplonski committed Sep 14, 2020
1 parent de82ad2 commit 94df8e9
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 39 deletions.
2 changes: 1 addition & 1 deletion examples/scripts/multi_class_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
X = df[["feature_1", "feature_2", "feature_3", "feature_4"]]
y = df["class"]

automl = AutoML()
automl = AutoML()

automl.fit(X, y)

Expand Down
63 changes: 25 additions & 38 deletions supervised/utils/importance.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,43 +48,30 @@ def compute_and_plot(
else:
scoring = "neg_mean_squared_error"

with warnings.catch_warnings():
warnings.simplefilter("ignore")
importance = permutation_importance(
model,
X_validation,
y_validation,
scoring=scoring,
n_jobs=-1, # all cores
random_state=12,
n_repeats=5, # default
)

sorted_idx = importance["importances_mean"].argsort()
try:
with warnings.catch_warnings():
warnings.simplefilter("ignore")
importance = permutation_importance(
model,
X_validation,
y_validation,
scoring=scoring,
n_jobs=-1, # all cores
random_state=12,
n_repeats=5, # default
)

# save detailed importance
df_imp = pd.DataFrame(
{
"feature": X_validation.columns[sorted_idx],
"mean_importance": importance["importances_mean"][sorted_idx],
}
)
df_imp.to_csv(
os.path.join(model_file_path, f"{learner_name}_importance.csv"), index=False
)
sorted_idx = importance["importances_mean"].argsort()

"""
# Do not plot. We will plot aggregate of all importances.
# limit number of column in the plot
if len(sorted_idx) > 50:
sorted_idx = sorted_idx[-50:]
plt.figure(figsize=(10, 7))
plt.barh(
X_validation.columns[sorted_idx], importance["importances_mean"][sorted_idx]
)
plt.xlabel("Mean of feature importance")
plt.tight_layout(pad=2.0)
plot_path = os.path.join(model_file_path, f"{learner_name}_importance.png")
plt.savefig(plot_path)
plt.close("all")
"""
# save detailed importance
df_imp = pd.DataFrame(
{
"feature": X_validation.columns[sorted_idx],
"mean_importance": importance["importances_mean"][sorted_idx],
}
)
df_imp.to_csv(
os.path.join(model_file_path, f"{learner_name}_importance.csv"), index=False
)
except Exception as e:
print("Problem during computing permutation importance. Skipping ...")

0 comments on commit 94df8e9

Please sign in to comment.