diff --git a/docs/_scripts/meta-models.py b/docs/_scripts/meta-models.py index 557273dbd..da473ab9c 100644 --- a/docs/_scripts/meta-models.py +++ b/docs/_scripts/meta-models.py @@ -67,9 +67,9 @@ plt.clf() # --8<-- [start:cross-validation-no-refit] -# %%time +# %%time -# Train an original model +# Train an original model orig_model = LogisticRegression(solver="lbfgs") orig_model.fit(X, y) @@ -111,7 +111,7 @@ def plot_model(model): df = load_chicken(as_frame=True) - + _ = model.fit(df[["diet", "time"]], df["weight"]) metric_df = (df[["diet", "time", "weight"]] .assign(pred=lambda d: model.predict(d[["diet", "time"]])) @@ -280,7 +280,7 @@ def plot_model(model): # --8<-- [start:decay-functions] -from sklego.meta._decay_utils import exponential_decay, linear_decay, sigmoid_decay, stepwise_decay +from sklego.meta._decay_utils import exponential_decay, linear_decay, sigmoid_decay, stepwise_decay fig = plt.figure(figsize=(12, 6)) @@ -312,13 +312,13 @@ def plot_model(model): np.random.seed(42) n1, n2, n3 = 100, 500, 50 -X = np.concatenate([np.random.normal(0, 1, (n1, 2)), +X = np.concatenate([np.random.normal(0, 1, (n1, 2)), np.random.normal(2, 1, (n2, 2)), - np.random.normal(3, 1, (n3, 2))], + np.random.normal(3, 1, (n3, 2))], axis=0) -y = np.concatenate([np.zeros((n1, 1)), +y = np.concatenate([np.zeros((n1, 1)), np.ones((n2, 1)), - np.zeros((n3, 1))], + np.zeros((n3, 1))], axis=0).reshape(-1) plt.scatter(X[:, 0], X[:, 1], c=y, cmap=cmap); # --8<-- [end:make-blobs] @@ -360,7 +360,7 @@ def false_negatives(mod, x, y): cf_mod = ConfusionBalancer(LogisticRegression(solver="lbfgs", max_iter=1000), alpha=1.0) grid = GridSearchCV( - cf_mod, + cf_mod, param_grid={"alpha": np.linspace(-1.0, 3.0, 31)}, scoring={ "accuracy": make_scorer(accuracy_score), @@ -464,4 +464,49 @@ def false_negatives(mod, x, y): from sklearn.utils import estimator_html_repr with open(_static_path / "outlier-classifier-stacking.html", "w") as f: - f.write(estimator_html_repr(stacker)) \ No newline at end of file + f.write(estimator_html_repr(stacker)) + +# --8<-- [start:ordinal-classifier-data] +import pandas as pd + +url = "https://stats.idre.ucla.edu/stat/data/ologit.dta" +df = pd.read_stata(url).assign(apply_codes = lambda t: t["apply"].cat.codes) + +target = "apply_codes" +features = [c for c in df.columns if c not in {target, "apply"}] + +X, y = df[features].to_numpy(), df[target].to_numpy() +df.head() +# --8<-- [end:ordinal-classifier-data] + +with open(_static_path / "ordinal_data.md", "w") as f: + f.write(df.head().to_markdown(index=False)) + +# --8<-- [start:ordinal-classifier] +from sklearn.linear_model import LogisticRegression +from sklego.meta import OrdinalClassifier + +ord_clf = OrdinalClassifier(LogisticRegression(), n_jobs=-1, use_calibration=False) +_ = ord_clf.fit(X, y) +ord_clf.predict_proba(X[0]) +# --8<-- [end:ordinal-classifier] + +print(ord_clf.predict_proba(X[0])) + +# --8<-- [start:ordinal-classifier-with-calibration] +from sklearn.calibration import CalibratedClassifierCV +from sklearn.linear_model import LogisticRegression +from sklego.meta import OrdinalClassifier + +calibration_kwargs = {...} + +ord_clf = OrdinalClassifier( + estimator=LogisticRegression(), + use_calibration=True, + calibration_kwargs=calibration_kwargs +) + +# This is equivalent to: +estimator = CalibratedClassifierCV(LogisticRegression(), **calibration_kwargs) +ord_clf = OrdinalClassifier(estimator) +# --8<-- [end:ordinal-classifier-with-calibration] diff --git a/docs/_static/meta-models/ordinal_data.md b/docs/_static/meta-models/ordinal_data.md new file mode 100644 index 000000000..105883bbd --- /dev/null +++ b/docs/_static/meta-models/ordinal_data.md @@ -0,0 +1,7 @@ +| apply | pared | public | gpa | apply_codes | +|:----------------|--------:|---------:|------:|--------------:| +| very likely | 0 | 0 | 3.26 | 2 | +| somewhat likely | 1 | 0 | 3.21 | 1 | +| unlikely | 1 | 1 | 3.94 | 0 | +| somewhat likely | 0 | 0 | 2.81 | 1 | +| somewhat likely | 0 | 0 | 2.53 | 1 | diff --git a/docs/user-guide/meta-models.md b/docs/user-guide/meta-models.md index dc0815d7c..b357eeb10 100644 --- a/docs/user-guide/meta-models.md +++ b/docs/user-guide/meta-models.md @@ -100,7 +100,6 @@ The image below demonstrates what will happen. ![grouped](../_static/meta-models/grouped-df.png) - We train 5 models in total because the model will also train a fallback automatically (you can turn this off via `use_fallback=False`). The idea behind the fallback is that we can predict something if there is a group at prediction time which is unseen during training. @@ -291,6 +290,7 @@ We'll perform an optimistic demonstration below. ```py --8<-- "docs/_scripts/meta-models.py:confusion-balancer-results" ``` + It seems that we can pick a value for $\alpha$ such that the confusion matrix is balanced. there's also a modest increase in accuracy for this balancing moment. It should be emphasized though that this feature is **experimental**. There have been dataset/model combinations where this effect seems to work very well while there have also been situations where this trick does not work at all. @@ -327,7 +327,7 @@ ZIR (RFC+RFR) r²: 0.8992404366385873 RFR r²: 0.8516522752031502 ``` -## OutlierClassifier +## Outlier Classifier Outlier models are unsupervised so they don't have `predict_proba` or `score` methods. @@ -381,6 +381,62 @@ The `OutlierClassifier` can be combined with any classification model in the `St --8<-- "docs/_static/meta-models/outlier-classifier-stacking.html" +## Ordinal Classification + +Ordinal classification (sometimes also referred to as Ordinal Refression) involves predicting an ordinal target variable, where the classes have a meaningful order. +Examples of this kind of problem are: predicting customer satisfaction on a scale from 1 to 5, predicting the severity of a disease, predicting the quality of a product, etc. + +The [`OrdinalClassifier`][ordinal-classifier-api] is a meta-model that can be used to transform any classifier into an ordinal classifier by fitting N-1 binary classifiers, each handling a specific class boundary, namely: $P(y <= 1), P(y <= 2), ..., P(y <= N-1)$. + +This implementation is based on the paper [A simple approach to ordinal classification][ordinal-classification-paper] and it allows to predict the ordinal probabilities of each sample belonging to a particular class. + +!!! note "mord library" + If you are looking for a library that implements other ordinal classification algorithms, you can have a look at the [mord][mord] library. + +```py title="Ordinal Data" +--8<-- "docs/_scripts/meta-models.py:ordinal-classifier-data" +``` + +--8<-- "docs/_static/meta-models/ordinal_data.md" + +Description of the dataset from [statsmodels tutorial][statsmodels-ordinal-regression]: + +> This dataset is about the probability for undergraduate students to apply to graduate school given three exogenous variables: +> +> - their grade point average (`gpa`), a float between 0 and 4. +> - `pared`, a binary that indicates if at least one parent went to graduate school. +> - `public`, a binary that indicates if the current undergraduate institution of the student is > public or private. +> +> `apply`, the target variable is categorical with ordered categories: "unlikely" < "somewhat likely" < "very likely". +> +> [...] +> +> For more details see the the Documentation of OrderedModel, [the UCLA webpage][ucla-webpage]. + +The only transformation we are applying to the data is to convert the target variable to an ordinal categorical variable by mapping the ordered categories to integers using their (pandas) category codes. + +We are now ready to train a [`OrdinalClassifier`][ordinal-classifier-api] on this dataset: + +```py title="OrdinalClassifier" +--8<-- "docs/_scripts/meta-models.py:ordinal-classifier" +``` + +> [[0.54883853 0.36225347 0.088908]] + +### Probability Calibration + +The `OrdinalClassifier` emphasizes the importance of proper probability estimates for its functionality. It is recommended to use the [`CalibratedClassifierCV`][calibrated-classifier-api] class from scikit-learn to calibrate the probabilities of the binary classifiers. + +Probability calibration is _not_ enabled by default, but we provide a convenient keyword argument `use_calibration` to enable it as follows: + +```py title="OrdinalClassifier with probability calibration" +--8<-- "docs/_scripts/meta-models.py:ordinal-classifier-with-calibration" +``` + +### Computation Time + +As a meta-estimator, the `OrdinalClassifier` fits N-1 binary classifiers, which may be computationally expensive, especially with a large number of samples, features, or a complex classifier. + [thresholder-api]: ../../api/meta#sklego.meta.thresholder.Thresholder [grouped-predictor-api]: ../../api/meta#sklego.meta.grouped_predictor.GroupedPredictor [grouped-transformer-api]: ../../api/meta#sklego.meta.grouped_transformer.GroupedTransformer @@ -389,8 +445,14 @@ The `OutlierClassifier` can be combined with any classification model in the `St [confusion-balancer-api]: ../../api/meta#sklego.meta.confusion_balancer.ConfusionBalancer [zero-inflated-api]: ../../api/meta#sklego.meta.zero_inflated_regressor.ZeroInflatedRegressor [outlier-classifier-api]: ../../api/meta#sklego.meta.outlier_classifier.OutlierClassifier +[ordinal-classifier-api]: ../../api/meta#sklego.meta.ordinal_classification.OrdinalClassifier [standard-scaler-api]: https://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.StandardScaler.html [stacking-classifier-api]: https://scikit-learn.org/stable/modules/generated/sklearn.ensemble.StackingClassifier.html#sklearn.ensemble.StackingClassifier [dummy-regressor-api]: https://scikit-learn.org/stable/modules/generated/sklearn.dummy.DummyRegressor.html [imb-learn]: https://imbalanced-learn.org/stable/ +[ordinal-classification-paper]: https://www.cs.waikato.ac.nz/~eibe/pubs/ordinal_tech_report.pdf +[mord]: https://pythonhosted.org/mord/ +[statsmodels-ordinal-regression]: https://www.statsmodels.org/dev/examples/notebooks/generated/ordinal_regression.html +[ucla-webpage]: https://stats.oarc.ucla.edu/r/dae/ordinal-logistic-regression/ +[calibrated-classifier-api]: https://scikit-learn.org/stable/modules/generated/sklearn.calibration.CalibratedClassifierCV.html diff --git a/pyproject.toml b/pyproject.toml index 62505fafc..b6a9ca67a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -95,6 +95,7 @@ sklego = ["data/*.zip"] [tool.ruff] line-length = 120 extend-select = ["I"] +exclude = ["docs"] [tool.pytest.ini_options] markers = [ diff --git a/sklego/meta/ordinal_classification.py b/sklego/meta/ordinal_classification.py index 2466a5ee4..1271750e0 100644 --- a/sklego/meta/ordinal_classification.py +++ b/sklego/meta/ordinal_classification.py @@ -3,7 +3,8 @@ from sklearn import clone from sklearn.base import BaseEstimator, ClassifierMixin, MetaEstimatorMixin, MultiOutputMixin, is_classifier from sklearn.calibration import CalibratedClassifierCV -from sklearn.utils.validation import check_is_fitted, check_X_y +from sklearn.metrics import accuracy_score +from sklearn.utils.validation import check_array, check_is_fitted, check_X_y class OrdinalClassifier(MultiOutputMixin, ClassifierMixin, MetaEstimatorMixin, BaseEstimator): @@ -26,8 +27,9 @@ class OrdinalClassifier(MultiOutputMixin, ClassifierMixin, MetaEstimatorMixin, B more like a confidence score). We recommend to use `CalibratedClassifierCV` to calibrate the probabilities of the binary classifiers. - This is enabled by default, but can be disabled by setting `use_calibration=False` and passing a calibrated - classifier to the `OrdinalClassifier` constructor. + + You can enable this by setting `use_calibration=True` and passing an uncalibrated classifier to the + `OrdinalClassifier` or by passing a calibrated classifier to the `OrdinalClassifier` constructor. More on this topic can be found in the [scikit-learn documentation](https://scikit-learn.org/stable/modules/calibration.html). @@ -41,9 +43,16 @@ class OrdinalClassifier(MultiOutputMixin, ClassifierMixin, MetaEstimatorMixin, B estimator : scikit-learn compatible classifier The estimator to be applied to the data, used as binary classifier. n_jobs : int, default=None - The number of jobs to run in parallel. `None` means 1, `-1` means using all processors. - use_calibration : bool, default=True + The number of jobs to run in parallel. The same convention of [`joblib.Parallel`](https://joblib.readthedocs.io/en/latest/generated/joblib.Parallel.html) + holds: + + - `n_jobs = None`: interpreted as n_jobs=1. + - `n_jobs > 0`: n_cpus=n_jobs are used. + - `n_jobs < 0`: (n_cpus + 1 + n_jobs) are used. + use_calibration : bool, default=False Whether or not to calibrate the binary classifiers using `CalibratedClassifierCV`. + calibrarion_kwargs : dict | None, default=None + Keyword arguments to the `CalibratedClassifierCV` class, used only if `use_calibration=True`. Attributes ---------- @@ -57,12 +66,23 @@ class OrdinalClassifier(MultiOutputMixin, ClassifierMixin, MetaEstimatorMixin, B Examples -------- ```py - from sklego.meta import OrdinalClassifier + import pandas as pd + from sklearn.linear_model import LogisticRegression + from sklearn.model_selection import train_test_split + + from sklego.meta import OrdinalClassifier - ... + url = "https://stats.idre.ucla.edu/stat/data/ologit.dta" + df = pd.read_stata(url).assign(apply_codes = lambda t: t["apply"].cat.codes) - clf = OrdinalClassifier(LogisticRegression()) + target = "apply_codes" + features = [c for c in df.columns if c not in {target, "apply"}] + + X, y = df[features].to_numpy(), df[target].to_numpy() + X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42) + + clf = OrdinalClassifier(LogisticRegression(), n_jobs=-1) _ = clf.fit(X_train, y_train) clf.predict_proba(X_test) ``` @@ -74,10 +94,11 @@ class OrdinalClassifier(MultiOutputMixin, ClassifierMixin, MetaEstimatorMixin, B """ - def __init__(self, estimator, *, n_jobs=None, use_calibration=True): + def __init__(self, estimator, *, n_jobs=None, use_calibration=False, **calibrarion_kwargs): self.estimator = estimator self.n_jobs = n_jobs self.use_calibration = use_calibration + self.calibrarion_kwargs = calibrarion_kwargs def fit(self, X, y): """Fit the `OrdinalClassifier` model on training data `X` and `y` by fitting its underlying estimators on @@ -112,6 +133,9 @@ def fit(self, X, y): self.classes_ = np.sort(np.unique(y)) self.n_features_in_ = X.shape[1] + if self.n_classes_ < 2: + raise ValueError("Classifier can't train when only one class is present.") + if self.n_jobs is None or self.n_jobs == 1: self.estimators_ = {y_label: self._fit_binary_estimator(X, y, y_label) for y_label in self.classes_[:-1]} else: @@ -123,6 +147,7 @@ def fit(self, X, y): ), ) ) + return self def predict_proba(self, X): @@ -146,6 +171,7 @@ def predict_proba(self, X): If `X` has a different number of features than the one seen during `fit`. """ check_is_fitted(self, ["estimators_", "classes_"]) + X = check_array(X, ensure_2d=True, estimator=self) if X.shape[1] != self.n_features_in_: raise ValueError(f"X has {X.shape[1]} features, expected {self.n_features_in_} features.") @@ -158,6 +184,7 @@ def predict_proba(self, X): def predict(self, X): """Predict class labels for samples in `X` as the class with the highest probability.""" + check_is_fitted(self, ["estimators_", "classes_"]) return self.classes_[np.argmax(self.predict_proba(X), axis=1)] def _fit_binary_estimator(self, X, y, y_label): @@ -178,8 +205,29 @@ def _fit_binary_estimator(self, X, y, y_label): The fitted binary classifier. """ y_bin = (y <= y_label).astype(int) - fitted_model = clone(self.estimator).fit(X, y_bin) if self.use_calibration: - return CalibratedClassifierCV(fitted_model, cv="prefit").fit(X, y_bin) + return CalibratedClassifierCV(estimator=clone(self.estimator), **self.calibrarion_kwargs).fit(X, y_bin) else: - return fitted_model + return clone(self.estimator).fit(X, y_bin) + + def score(self, X, y): + """Returns the accuracy score on the given test data and labels. + + Parameters + ---------- + X : array-like of shape (n_samples, n_features ) + The training data. + y : array-like of shape (n_samples,) + The target values. + + Returns + ------- + score : float + Accuracy score of self.predict(X) wrt. y. + """ + return accuracy_score(y, self.predict(X)) + + @property + def n_classes_(self): + """Number of classes.""" + return len(self.classes_) diff --git a/tests/test_meta/test_grouped_transformer.py b/tests/test_meta/test_grouped_transformer.py index 3fad4af7d..ef33a6a9c 100644 --- a/tests/test_meta/test_grouped_transformer.py +++ b/tests/test_meta/test_grouped_transformer.py @@ -187,7 +187,7 @@ def test_multiple_grouping_columns(dataset_with_multiple_grouping, scaling_range assert np.allclose(df_with_groups.groupby(["A", "B"]).min(), scaling_range[0]) - # If a group has a single element, it defaults to min, so check wether all maxes are one of the bounds + # If a group has a single element, it defaults to min, so check whether all maxes are one of the bounds maxes = df_with_groups.groupby(["A", "B"]).max() assert np.all( np.isclose(maxes, scaling_range[1]) | np.isclose(maxes, scaling_range[0]) diff --git a/tests/test_meta/test_ordinal_classification.py b/tests/test_meta/test_ordinal_classification.py new file mode 100644 index 000000000..bfc1bd7f6 --- /dev/null +++ b/tests/test_meta/test_ordinal_classification.py @@ -0,0 +1,51 @@ +import numpy as np +import pytest +from sklearn.linear_model import LinearRegression, LogisticRegression, RidgeClassifier + +from sklego.common import flatten +from sklego.meta import OrdinalClassifier +from tests.conftest import classifier_checks, general_checks, select_tests + + +@pytest.fixture +def random_xy_ordinal(): + np.random.seed(42) + X = np.random.normal(0, 2, (1000, 3)) + y = np.select(condlist=[X[:, 0] < 2, X[:, 1] > 2], choicelist=[0, 2], default=1) + return X, y + + +@pytest.mark.parametrize("test_fn", select_tests(flatten([general_checks, classifier_checks]))) +def test_estimator_checks(test_fn): + ord_clf = OrdinalClassifier(estimator=LogisticRegression()) + test_fn("OrdinalClassifier", ord_clf) + + +@pytest.mark.parametrize( + "estimator, context, err_msg", + [ + (LinearRegression(), pytest.raises(ValueError), "The estimator must be a classifier."), + (RidgeClassifier(), pytest.raises(ValueError), "The estimator must implement `.predict_proba()` method."), + ], +) +def test_raises_error(random_xy_ordinal, estimator, context, err_msg): + X, y = random_xy_ordinal + with context as exc_info: + ord_clf = OrdinalClassifier(estimator=estimator) + ord_clf.fit(X, y) + + if exc_info: + assert err_msg in str(exc_info.value) + + +@pytest.mark.parametrize("n_jobs", [-2, -1, 2, None]) +@pytest.mark.parametrize("use_calibration", [True, False]) +def test_can_fit_param_combination(random_xy_ordinal, n_jobs, use_calibration): + X, y = random_xy_ordinal + ord_clf = OrdinalClassifier(estimator=LogisticRegression(), n_jobs=n_jobs, use_calibration=use_calibration) + _ = ord_clf.fit(X, y) + + assert ord_clf.n_jobs == n_jobs + assert ord_clf.use_calibration == use_calibration + assert ord_clf.n_classes_ == 3 + assert ord_clf.n_features_in_ == X.shape[1] diff --git a/tmp/ordinal_classification_demo.ipynb b/tmp/ordinal_classification_demo.ipynb index ded0bc7c3..c5cd341ab 100644 --- a/tmp/ordinal_classification_demo.ipynb +++ b/tmp/ordinal_classification_demo.ipynb @@ -2,16 +2,16 @@ "cells": [ { "cell_type": "code", - "execution_count": 13, + "execution_count": 9, "metadata": {}, "outputs": [], "source": [ - "# ! pip install git+https://github.com/FBruzzesi/scikit-lego.git@ordinal-classification sklearn pandas xlrd" + "# ! pip install git+https://github.com/FBruzzesi/scikit-lego.git@ordinal-classification" ] }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 10, "metadata": {}, "outputs": [], "source": [ @@ -20,6 +20,7 @@ "\n", "from sklearn.linear_model import LogisticRegression\n", "from sklearn.ensemble import RandomForestClassifier, ExtraTreesClassifier, HistGradientBoostingClassifier\n", + "from sklearn.calibration import CalibratedClassifierCV\n", "from sklearn.model_selection import cross_validate\n", "from sklearn.multiclass import OneVsRestClassifier\n", "\n", @@ -32,7 +33,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 11, "metadata": {}, "outputs": [], "source": [ @@ -48,14 +49,14 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "def score_estimator(estimator, X, y, scoring) -> pd.DataFrame:\n", "\n", " return (\n", - " pd.DataFrame(cross_validate(estimator, X, y, cv=10, scoring=scoring, n_jobs=-1))\n", + " pd.DataFrame(cross_validate(estimator, X, y, cv=5, scoring=scoring, n_jobs=-1))\n", " .loc[:, [f\"test_{s}\" for s in scoring.keys()]]\n", " .rename(columns={f\"test_{s}\": s for s in scoring.keys()})\n", " )\n", @@ -69,12 +70,56 @@ " ovr_scores = score_estimator(ovr_estimator, X, y, scoring)\n", "\n", " scores = pd.merge(oc_scores, ovr_scores, left_index=True, right_index=True, suffixes=[\"_oc\", \"_ovr\"])\n", - " return (scores.reindex(sorted(scores.columns), axis=1))" + " return (scores.reindex(sorted(scores.columns), axis=1))\n", + "\n", + "def fit_predict_meta_models(base_estimator, X, y):\n", + "\n", + " oc_cal_estimator = OrdinalClassifier(clone(base_estimator), use_calibration=True, n_jobs=-1)\n", + " oc_cal_preds = oc_cal_estimator.fit(X, y).predict_proba(X)\n", + "\n", + " oc_no_cal_estimator = OrdinalClassifier(clone(base_estimator), use_calibration=False, n_jobs=-1)\n", + " oc_no_cal_preds = oc_no_cal_estimator.fit(X, y).predict_proba(X)\n", + "\n", + " ovr_estimator = OneVsRestClassifier(clone(base_estimator), n_jobs=-1)\n", + " ovr_preds = ovr_estimator.fit(X, y).predict_proba(X)\n", + "\n", + " return oc_cal_preds, oc_no_cal_preds, ovr_preds\n", + "\n", + "def is_monotonic(row, split_idx):\n", + " left, right = np.split(row, [split_idx])\n", + " is_monotonic = (np.diff(left)>0).all() & (np.diff(right)<0).all()\n", + " return is_monotonic\n", + "\n", + "def check_monotonicity(arr):\n", + " argmax = np.argmax(arr, axis=1)\n", + " \n", + " return np.array([is_monotonic(row, split_idx) for row, split_idx in zip(arr, argmax)])\n" ] }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([ True, True, True, False, False])" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "a = np.array([[.1, .2, .3], [.3, .2, .1], [.1, .2, .1], [.3, .1, .2], [.2, .1, .3]])\n", + "check_monotonicity(a)" + ] + }, + { + "cell_type": "code", + "execution_count": 14, "metadata": {}, "outputs": [], "source": [ @@ -83,39 +128,45 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 15, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "LogisticRegression\n", - " accuracy_oc accuracy_ovr f1_oc f1_ovr\n", - "count 10.000000 10.000000 10.000000 10.000000\n", - "mean 0.379545 0.365584 0.496891 0.480228\n", - "std 0.027217 0.012178 0.030147 0.024780\n", - "min 0.350649 0.350649 0.447083 0.447083\n", - "50% 0.377706 0.366883 0.505263 0.486443\n", - "max 0.442641 0.389610 0.530615 0.517500\n", + "LogisticRegression\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + " accuracy_oc accuracy_ovr f1_oc f1_ovr\n", + "count 5.000000 5.000000 5.000000 5.000000\n", + "mean 0.368831 0.364069 0.487874 0.481174\n", + "std 0.010307 0.007482 0.022214 0.016987\n", + "min 0.353896 0.353896 0.453290 0.453290\n", + "50% 0.374459 0.362554 0.500000 0.483333\n", + "max 0.378788 0.374459 0.506667 0.498835\n", "\n", "ExtraTreesClassifier\n", - " accuracy_oc accuracy_ovr f1_oc f1_ovr\n", - "count 10.000000 10.000000 10.000000 10.000000\n", - "mean 0.376082 0.366342 0.482819 0.479117\n", - "std 0.052527 0.040000 0.049658 0.042576\n", - "min 0.296537 0.296537 0.405556 0.398889\n", - "50% 0.358225 0.358225 0.488406 0.480000\n", - "max 0.451299 0.449134 0.554097 0.532857\n", + " accuracy_oc accuracy_ovr f1_oc f1_ovr\n", + "count 5.000000 5.000000 5.000000 5.000000\n", + "mean 0.353680 0.377056 0.462485 0.490132\n", + "std 0.014867 0.038643 0.025648 0.041356\n", + "min 0.331169 0.327922 0.433073 0.445189\n", + "50% 0.353896 0.374459 0.459792 0.501344\n", + "max 0.371212 0.424784 0.502998 0.546011\n", "\n", "HistGradientBoostingClassifier\n", - " accuracy_oc accuracy_ovr f1_oc f1_ovr\n", - "count 10.000000 10.000000 10.000000 10.000000\n", - "mean 0.319481 0.319264 0.410795 0.437852\n", - "std 0.055066 0.038031 0.047868 0.051737\n", - "min 0.244589 0.238095 0.352575 0.338216\n", - "50% 0.324675 0.325758 0.431075 0.447599\n", - "max 0.390693 0.376623 0.464660 0.512500\n", + " accuracy_oc accuracy_ovr f1_oc f1_ovr\n", + "count 5.000000 5.000000 5.000000 5.000000\n", + "mean 0.332900 0.346861 0.399589 0.453681\n", + "std 0.004630 0.074274 0.020394 0.075310\n", + "min 0.325758 0.228355 0.390323 0.326167\n", + "50% 0.333333 0.356602 0.390323 0.463303\n", + "max 0.338745 0.431818 0.436068 0.514889\n", "\n" ] } @@ -130,6 +181,58 @@ " print(scores.describe(percentiles=[0.5]))\n", " print()" ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "LogisticRegression\n", + "Percentage of monotonic constraint respect:\n", + "\n", + "\tOrdinalClassifier with calibration: 100.00%\n", + "\tOrdinalClassifier without calibration: 100.00%\n", + "\tOVR Classifier: 100.00%\n", + "\n", + "ExtraTreesClassifier\n", + "Percentage of monotonic constraint respect:\n", + "\n", + "\tOrdinalClassifier with calibration: 100.00%\n", + "\tOrdinalClassifier without calibration: 94.50%\n", + "\tOVR Classifier: 94.75%\n", + "\n", + "HistGradientBoostingClassifier\n", + "Percentage of monotonic constraint respect:\n", + "\n", + "\tOrdinalClassifier with calibration: 100.00%\n", + "\tOrdinalClassifier without calibration: 89.25%\n", + "\tOVR Classifier: 89.25%\n", + "\n" + ] + } + ], + "source": [ + "estimators = [LogisticRegression(), ExtraTreesClassifier(max_depth=5), HistGradientBoostingClassifier(max_depth=5)]\n", + "\n", + "for base_estimator in estimators:\n", + " print(base_estimator.__class__.__name__)\n", + " oc_cal_preds, oc_no_cal_preds, ovr_preds = fit_predict_meta_models(base_estimator, X, y)\n", + "\n", + "\n", + " print(\"Percentage of monotonic constraint respect:\\n\")\n", + " print(\n", + " f\"\\tOrdinalClassifier with calibration: {100*check_monotonicity(oc_cal_preds).mean():.2f}%\",\n", + " f\"\\tOrdinalClassifier without calibration: {100*check_monotonicity(oc_no_cal_preds).mean():.2f}%\",\n", + " f\"\\tOVR Classifier: {100*check_monotonicity(ovr_preds).mean():.2f}%\",\n", + " sep=\"\\n\"\n", + " )\n", + "\n", + " print()" + ] } ], "metadata": {