Skip to content

Commit

Permalink
demo notebook
Browse files Browse the repository at this point in the history
  • Loading branch information
FBruzzesi committed Jan 22, 2024
1 parent 4175741 commit 63652cd
Show file tree
Hide file tree
Showing 2 changed files with 160 additions and 4 deletions.
8 changes: 4 additions & 4 deletions sklego/meta/ordinal_classification.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import numpy as np
from joblib import Parallel, delayed
from sklearn import clone
from sklearn.base import BaseEstimator, ClassifierMixin, MetaEstimatorMixin, is_classifier
from sklearn.base import BaseEstimator, ClassifierMixin, MetaEstimatorMixin, MultiOutputMixin, is_classifier
from sklearn.calibration import CalibratedClassifierCV
from sklearn.utils.validation import check_is_fitted, check_X_y


class OrdinalClassifier(BaseEstimator, ClassifierMixin, MetaEstimatorMixin):
class OrdinalClassifier(MultiOutputMixin, ClassifierMixin, MetaEstimatorMixin, BaseEstimator):
r"""The `OrdinalClassifier` allows to use a binary classifier to address an ordinal classification problem.
Suppose we have N ordinal classes to predict, then the original binary classifier is fitted on N-1 by training sets,
Expand Down Expand Up @@ -74,7 +74,7 @@ class OrdinalClassifier(BaseEstimator, ClassifierMixin, MetaEstimatorMixin):
"""

def __init__(self, estimator, n_jobs=None, use_calibration=True):
def __init__(self, estimator, *, n_jobs=None, use_calibration=True):
self.estimator = estimator
self.n_jobs = n_jobs
self.use_calibration = use_calibration
Expand Down Expand Up @@ -104,7 +104,7 @@ def fit(self, X, y):
if not is_classifier(self.estimator):
raise ValueError("The estimator must be a classifier.")

if hasattr(self.estimator, "predict_proba"):
if not hasattr(self.estimator, "predict_proba"):
raise ValueError("The estimator must implement `.predict_proba()` method.")

X, y = check_X_y(X, y, estimator=self)
Expand Down
156 changes: 156 additions & 0 deletions tmp/ordinal_classification_demo.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"# ! pip install git+https://github.com/FBruzzesi/scikit-lego.git@ordinal-classification sklearn pandas xlrd"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"import numpy as np\n",
"\n",
"from sklearn.linear_model import LogisticRegression\n",
"from sklearn.ensemble import RandomForestClassifier, ExtraTreesClassifier, HistGradientBoostingClassifier\n",
"from sklearn.model_selection import cross_validate\n",
"from sklearn.multiclass import OneVsRestClassifier\n",
"\n",
"from sklearn.metrics import classification_report, balanced_accuracy_score, f1_score, make_scorer\n",
"\n",
"from sklearn import clone\n",
"\n",
"from sklego.meta import OrdinalClassifier"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"url = \"https://stats.idre.ucla.edu/stat/data/ologit.dta\"\n",
"\n",
"df = pd.read_stata(url).assign(apply = lambda t: t[\"apply\"].cat.codes)\n",
"\n",
"target = \"apply\"\n",
"features = [c for c in df.columns if c != target]\n",
"\n",
"X, y = df[features].to_numpy(), df[target].to_numpy()"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"def score_estimator(estimator, X, y, scoring) -> pd.DataFrame:\n",
"\n",
" return (\n",
" pd.DataFrame(cross_validate(estimator, X, y, cv=10, scoring=scoring, n_jobs=-1))\n",
" .loc[:, [f\"test_{s}\" for s in scoring.keys()]]\n",
" .rename(columns={f\"test_{s}\": s for s in scoring.keys()})\n",
" )\n",
"\n",
"def compare_meta_models(base_estimator, X, y, scoring) -> pd.DataFrame:\n",
"\n",
" oc_estimator = OrdinalClassifier(clone(base_estimator), use_calibration=True, n_jobs=-1)\n",
" oc_scores = score_estimator(oc_estimator, X, y, scoring)\n",
" \n",
" ovr_estimator = OneVsRestClassifier(clone(base_estimator), n_jobs=-1)\n",
" ovr_scores = score_estimator(ovr_estimator, X, y, scoring)\n",
"\n",
" scores = pd.merge(oc_scores, ovr_scores, left_index=True, right_index=True, suffixes=[\"_oc\", \"_ovr\"])\n",
" return (scores.reindex(sorted(scores.columns), axis=1))"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
"scoring = {'accuracy': make_scorer(balanced_accuracy_score), \"f1\": make_scorer(f1_score, average=\"weighted\")}"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"LogisticRegression\n",
" accuracy_oc accuracy_ovr f1_oc f1_ovr\n",
"count 10.000000 10.000000 10.000000 10.000000\n",
"mean 0.379545 0.365584 0.496891 0.480228\n",
"std 0.027217 0.012178 0.030147 0.024780\n",
"min 0.350649 0.350649 0.447083 0.447083\n",
"50% 0.377706 0.366883 0.505263 0.486443\n",
"max 0.442641 0.389610 0.530615 0.517500\n",
"\n",
"ExtraTreesClassifier\n",
" accuracy_oc accuracy_ovr f1_oc f1_ovr\n",
"count 10.000000 10.000000 10.000000 10.000000\n",
"mean 0.376082 0.366342 0.482819 0.479117\n",
"std 0.052527 0.040000 0.049658 0.042576\n",
"min 0.296537 0.296537 0.405556 0.398889\n",
"50% 0.358225 0.358225 0.488406 0.480000\n",
"max 0.451299 0.449134 0.554097 0.532857\n",
"\n",
"HistGradientBoostingClassifier\n",
" accuracy_oc accuracy_ovr f1_oc f1_ovr\n",
"count 10.000000 10.000000 10.000000 10.000000\n",
"mean 0.319481 0.319264 0.410795 0.437852\n",
"std 0.055066 0.038031 0.047868 0.051737\n",
"min 0.244589 0.238095 0.352575 0.338216\n",
"50% 0.324675 0.325758 0.431075 0.447599\n",
"max 0.390693 0.376623 0.464660 0.512500\n",
"\n"
]
}
],
"source": [
"estimators = [LogisticRegression(), ExtraTreesClassifier(max_depth=5), HistGradientBoostingClassifier(max_depth=5)]\n",
"\n",
"for base_estimator in estimators:\n",
"\n",
" print(base_estimator.__class__.__name__)\n",
" scores = compare_meta_models(base_estimator, X, y, scoring)\n",
" print(scores.describe(percentiles=[0.5]))\n",
" print()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "lego-dev",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.13"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

0 comments on commit 63652cd

Please sign in to comment.