Commit: tests,docs,api change

* unit tests

* docs

* change calibration api
FBruzzesi authored Jan 26, 2024
1 parent 63652cd commit 95cfd85
Showing 8 changed files with 373 additions and 56 deletions.
65 changes: 55 additions & 10 deletions docs/_scripts/meta-models.py
@@ -67,9 +67,9 @@
plt.clf()

# --8<-- [start:cross-validation-no-refit]
# %%time

# Train an original model
orig_model = LogisticRegression(solver="lbfgs")
orig_model.fit(X, y)

@@ -111,7 +111,7 @@

def plot_model(model):
df = load_chicken(as_frame=True)

_ = model.fit(df[["diet", "time"]], df["weight"])
metric_df = (df[["diet", "time", "weight"]]
.assign(pred=lambda d: model.predict(d[["diet", "time"]]))
@@ -280,7 +280,7 @@ def plot_model(model):


# --8<-- [start:decay-functions]
from sklego.meta._decay_utils import exponential_decay, linear_decay, sigmoid_decay, stepwise_decay

fig = plt.figure(figsize=(12, 6))

@@ -312,13 +312,13 @@ def plot_model(model):
np.random.seed(42)

n1, n2, n3 = 100, 500, 50
X = np.concatenate([np.random.normal(0, 1, (n1, 2)),
np.random.normal(2, 1, (n2, 2)),
np.random.normal(3, 1, (n3, 2))],
axis=0)
y = np.concatenate([np.zeros((n1, 1)),
np.ones((n2, 1)),
np.zeros((n3, 1))],
axis=0).reshape(-1)
plt.scatter(X[:, 0], X[:, 1], c=y, cmap=cmap);
# --8<-- [end:make-blobs]
@@ -360,7 +360,7 @@ def false_negatives(mod, x, y):
cf_mod = ConfusionBalancer(LogisticRegression(solver="lbfgs", max_iter=1000), alpha=1.0)

grid = GridSearchCV(
cf_mod,
param_grid={"alpha": np.linspace(-1.0, 3.0, 31)},
scoring={
"accuracy": make_scorer(accuracy_score),
@@ -464,4 +464,49 @@ def false_negatives(mod, x, y):

from sklearn.utils import estimator_html_repr
with open(_static_path / "outlier-classifier-stacking.html", "w") as f:
f.write(estimator_html_repr(stacker))

# --8<-- [start:ordinal-classifier-data]
import pandas as pd

url = "https://stats.idre.ucla.edu/stat/data/ologit.dta"
df = pd.read_stata(url).assign(apply_codes=lambda t: t["apply"].cat.codes)

target = "apply_codes"
features = [c for c in df.columns if c not in {target, "apply"}]

X, y = df[features].to_numpy(), df[target].to_numpy()
df.head()
# --8<-- [end:ordinal-classifier-data]

with open(_static_path / "ordinal_data.md", "w") as f:
f.write(df.head().to_markdown(index=False))

# --8<-- [start:ordinal-classifier]
from sklearn.linear_model import LogisticRegression
from sklego.meta import OrdinalClassifier

ord_clf = OrdinalClassifier(LogisticRegression(), n_jobs=-1, use_calibration=False)
_ = ord_clf.fit(X, y)
ord_clf.predict_proba(X[[0]])
# --8<-- [end:ordinal-classifier]

print(ord_clf.predict_proba(X[[0]]))

# --8<-- [start:ordinal-classifier-with-calibration]
from sklearn.calibration import CalibratedClassifierCV
from sklearn.linear_model import LogisticRegression
from sklego.meta import OrdinalClassifier

calibration_kwargs = {...}

ord_clf = OrdinalClassifier(
estimator=LogisticRegression(),
use_calibration=True,
calibration_kwargs=calibration_kwargs
)

# This is equivalent to:
estimator = CalibratedClassifierCV(LogisticRegression(), **calibration_kwargs)
ord_clf = OrdinalClassifier(estimator)
# --8<-- [end:ordinal-classifier-with-calibration]
7 changes: 7 additions & 0 deletions docs/_static/meta-models/ordinal_data.md
@@ -0,0 +1,7 @@
| apply | pared | public | gpa | apply_codes |
|:----------------|--------:|---------:|------:|--------------:|
| very likely | 0 | 0 | 3.26 | 2 |
| somewhat likely | 1 | 0 | 3.21 | 1 |
| unlikely | 1 | 1 | 3.94 | 0 |
| somewhat likely | 0 | 0 | 2.81 | 1 |
| somewhat likely | 0 | 0 | 2.53 | 1 |
66 changes: 64 additions & 2 deletions docs/user-guide/meta-models.md
@@ -100,7 +100,6 @@ The image below demonstrates what will happen.

![grouped](../_static/meta-models/grouped-df.png)

We train 5 models in total because the model will also train a fallback automatically (you can turn this off via `use_fallback=False`).

The idea behind the fallback is that we can predict something if there is a group at prediction time which is unseen during training.
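
A minimal sketch of this pattern, reusing the chicken-weight setup from the script in this commit (the `use_fallback` flag is the one named above):

```py
from sklearn.linear_model import LinearRegression
from sklego.datasets import load_chicken
from sklego.meta import GroupedPredictor

df = load_chicken(as_frame=True)

# One model per "diet" group, plus an automatically trained fallback model.
model = GroupedPredictor(LinearRegression(), groups=["diet"])
_ = model.fit(df[["diet", "time"]], df["weight"])

# Rows whose "diet" value was unseen during training are handled
# by the fallback model instead of raising an error.
```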
@@ -291,6 +290,7 @@ We'll perform an optimistic demonstration below.
```py
--8<-- "docs/_scripts/meta-models.py:confusion-balancer-results"
```

It seems that we can pick a value for $\alpha$ such that the confusion matrix is balanced. There is also a modest increase in accuracy at this point.

It should be emphasized though that this feature is **experimental**. There have been dataset/model combinations where this effect seems to work very well while there have also been situations where this trick does not work at all.
@@ -327,7 +327,7 @@ ZIR (RFC+RFR) r²: 0.8992404366385873
RFR r²: 0.8516522752031502
```

## Outlier Classifier

Outlier models are unsupervised, so they don't have `predict_proba` or `score` methods.
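
A quick sketch of the wrapper (assuming `X` and `y` from a classification dataset; `IsolationForest` is just one possible outlier model):

```py
from sklearn.ensemble import IsolationForest
from sklego.meta import OutlierClassifier

outlier_clf = OutlierClassifier(IsolationForest(random_state=0))
_ = outlier_clf.fit(X, y)

# Now available, unlike on the raw outlier model:
outlier_clf.predict_proba(X)
```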

@@ -381,6 +381,62 @@ The `OutlierClassifier` can be combined with any classification model in the `St

--8<-- "docs/_static/meta-models/outlier-classifier-stacking.html"

## Ordinal Classification

Ordinal classification (sometimes also referred to as Ordinal Regression) involves predicting an ordinal target variable, where the classes have a meaningful order.
Examples of this kind of problem are: predicting customer satisfaction on a scale from 1 to 5, predicting the severity of a disease, predicting the quality of a product, etc.

The [`OrdinalClassifier`][ordinal-classifier-api] is a meta-model that can be used to transform any classifier into an ordinal classifier by fitting N-1 binary classifiers, each handling a specific class boundary, namely: $P(y \leq 1), P(y \leq 2), \ldots, P(y \leq N-1)$.

This implementation is based on the paper [A simple approach to ordinal classification][ordinal-classification-paper] and allows predicting, for each sample, the probability of belonging to each of the ordered classes.
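
Concretely, the binary estimators produce cumulative probabilities $P(y \leq k)$, and per-class probabilities are recovered by taking differences: $P(y = k) = P(y \leq k) - P(y \leq k - 1)$. A minimal sketch with made-up numbers (three classes, three samples):

```py title="Recombining cumulative probabilities (sketch)"
import numpy as np

# Made-up outputs of the N-1 = 2 binary classifiers for 3 samples
p_le_1 = np.array([0.55, 0.10, 0.30])  # P(y <= 1)
p_le_2 = np.array([0.91, 0.40, 0.80])  # P(y <= 2)

# Pad with P(y <= 0) = 0 and P(y <= 3) = 1, then take differences:
# column k holds P(y = k + 1)
p_le = np.column_stack([np.zeros(3), p_le_1, p_le_2, np.ones(3)])
class_probas = np.diff(p_le, axis=1)

print(class_probas.sum(axis=1))  # [1. 1. 1.]
```

Note that nothing forces the raw binary outputs to be monotone in $k$, which is one reason well-calibrated probabilities matter (see the calibration section below).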

!!! note "mord library"
If you are looking for a library that implements other ordinal classification algorithms, you can have a look at the [mord][mord] library.

```py title="Ordinal Data"
--8<-- "docs/_scripts/meta-models.py:ordinal-classifier-data"
```

--8<-- "docs/_static/meta-models/ordinal_data.md"

A description of the dataset, from the [statsmodels tutorial][statsmodels-ordinal-regression]:

> This dataset is about the probability for undergraduate students to apply to graduate school given three exogenous variables:
>
> - their grade point average (`gpa`), a float between 0 and 4.
> - `pared`, a binary that indicates if at least one parent went to graduate school.
> - `public`, a binary that indicates if the current undergraduate institution of the student is public or private.
>
> `apply`, the target variable is categorical with ordered categories: "unlikely" < "somewhat likely" < "very likely".
>
> [...]
>
> For more details see the documentation of `OrderedModel` and [the UCLA webpage][ucla-webpage].

The only transformation we apply to the data is to convert the target variable to an ordinal categorical variable by mapping the ordered categories to integers using their (pandas) category codes.
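
As a toy illustration of that mapping (the values below are invented, not rows of the dataset above):

```py title="Category codes (sketch)"
import pandas as pd

apply_col = pd.Series(
    pd.Categorical(
        ["unlikely", "somewhat likely", "very likely", "somewhat likely"],
        categories=["unlikely", "somewhat likely", "very likely"],
        ordered=True,
    )
)
apply_col.cat.codes.tolist()  # [0, 1, 2, 1]
```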

We are now ready to train an [`OrdinalClassifier`][ordinal-classifier-api] on this dataset:

```py title="OrdinalClassifier"
--8<-- "docs/_scripts/meta-models.py:ordinal-classifier"
```

> [[0.54883853 0.36225347 0.088908]]

### Probability Calibration

The `OrdinalClassifier` relies on well-calibrated probability estimates from its binary classifiers. We therefore recommend using the [`CalibratedClassifierCV`][calibrated-classifier-api] class from scikit-learn to calibrate their probabilities.

Probability calibration is _not_ enabled by default, but we provide a convenient keyword argument `use_calibration` to enable it as follows:

```py title="OrdinalClassifier with probability calibration"
--8<-- "docs/_scripts/meta-models.py:ordinal-classifier-with-calibration"
```
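
The `calibration_kwargs` dictionary is forwarded to `CalibratedClassifierCV`, so any of its keyword arguments can be used. One possible configuration (these particular values are an illustrative assumption, not a recommendation from the docs):

```py title="Example calibration_kwargs (sketch)"
# Hypothetical choice: Platt scaling with 3-fold cross-validation
calibration_kwargs = {"method": "sigmoid", "cv": 3}

ord_clf = OrdinalClassifier(
    estimator=LogisticRegression(),
    use_calibration=True,
    calibration_kwargs=calibration_kwargs,
)
```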

### Computation Time

As a meta-estimator, the `OrdinalClassifier` fits N-1 binary classifiers, which may be computationally expensive, especially with a large number of samples, features, or a complex classifier.
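
Fitting the binary classifiers in parallel via `n_jobs` can help, e.g. (reusing `X` and `y` from the snippets above):

```py title="Parallel fitting (sketch)"
ord_clf = OrdinalClassifier(LogisticRegression(), n_jobs=-1)  # all available processors
_ = ord_clf.fit(X, y)
```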

[thresholder-api]: ../../api/meta#sklego.meta.thresholder.Thresholder
[grouped-predictor-api]: ../../api/meta#sklego.meta.grouped_predictor.GroupedPredictor
[grouped-transformer-api]: ../../api/meta#sklego.meta.grouped_transformer.GroupedTransformer
@@ -389,8 +445,14 @@ The `OutlierClassifier` can be combined with any classification model in the `St
[confusion-balancer-api]: ../../api/meta#sklego.meta.confusion_balancer.ConfusionBalancer
[zero-inflated-api]: ../../api/meta#sklego.meta.zero_inflated_regressor.ZeroInflatedRegressor
[outlier-classifier-api]: ../../api/meta#sklego.meta.outlier_classifier.OutlierClassifier
[ordinal-classifier-api]: ../../api/meta#sklego.meta.ordinal_classification.OrdinalClassifier

[standard-scaler-api]: https://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.StandardScaler.html
[stacking-classifier-api]: https://scikit-learn.org/stable/modules/generated/sklearn.ensemble.StackingClassifier.html#sklearn.ensemble.StackingClassifier
[dummy-regressor-api]: https://scikit-learn.org/stable/modules/generated/sklearn.dummy.DummyRegressor.html
[imb-learn]: https://imbalanced-learn.org/stable/
[ordinal-classification-paper]: https://www.cs.waikato.ac.nz/~eibe/pubs/ordinal_tech_report.pdf
[mord]: https://pythonhosted.org/mord/
[statsmodels-ordinal-regression]: https://www.statsmodels.org/dev/examples/notebooks/generated/ordinal_regression.html
[ucla-webpage]: https://stats.oarc.ucla.edu/r/dae/ordinal-logistic-regression/
[calibrated-classifier-api]: https://scikit-learn.org/stable/modules/generated/sklearn.calibration.CalibratedClassifierCV.html
1 change: 1 addition & 0 deletions pyproject.toml
@@ -95,6 +95,7 @@ sklego = ["data/*.zip"]
[tool.ruff]
line-length = 120
extend-select = ["I"]
exclude = ["docs"]

[tool.pytest.ini_options]
markers = [
72 changes: 60 additions & 12 deletions sklego/meta/ordinal_classification.py
@@ -3,7 +3,8 @@
from sklearn import clone
from sklearn.base import BaseEstimator, ClassifierMixin, MetaEstimatorMixin, MultiOutputMixin, is_classifier
from sklearn.calibration import CalibratedClassifierCV
from sklearn.metrics import accuracy_score
from sklearn.utils.validation import check_array, check_is_fitted, check_X_y


class OrdinalClassifier(MultiOutputMixin, ClassifierMixin, MetaEstimatorMixin, BaseEstimator):
@@ -26,8 +27,9 @@ class OrdinalClassifier(MultiOutputMixin, ClassifierMixin, MetaEstimatorMixin, B
more like a confidence score).
We recommend using `CalibratedClassifierCV` to calibrate the probabilities of the binary classifiers.
You can enable this by setting `use_calibration=True` and passing an uncalibrated classifier to the
`OrdinalClassifier`, or by passing an already calibrated classifier to the `OrdinalClassifier` constructor.
More on this topic can be found in the [scikit-learn documentation](https://scikit-learn.org/stable/modules/calibration.html).
@@ -41,9 +43,16 @@ class OrdinalClassifier(MultiOutputMixin, ClassifierMixin, MetaEstimatorMixin, B
estimator : scikit-learn compatible classifier
The estimator to be applied to the data, used as binary classifier.
n_jobs : int, default=None
The number of jobs to run in parallel. The same convention as [`joblib.Parallel`](https://joblib.readthedocs.io/en/latest/generated/joblib.Parallel.html) holds:
- `n_jobs = None`: interpreted as `n_jobs=1`.
- `n_jobs > 0`: `n_jobs` processors are used.
- `n_jobs < 0`: `(n_cpus + 1 + n_jobs)` processors are used.
use_calibration : bool, default=False
Whether or not to calibrate the binary classifiers using `CalibratedClassifierCV`.
calibration_kwargs : dict | None, default=None
Keyword arguments passed to the `CalibratedClassifierCV` class, used only if `use_calibration=True`.
Attributes
----------
@@ -57,12 +66,23 @@ class OrdinalClassifier(MultiOutputMixin, ClassifierMixin, MetaEstimatorMixin, B
Examples
--------
```py
import pandas as pd
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklego.meta import OrdinalClassifier

url = "https://stats.idre.ucla.edu/stat/data/ologit.dta"
df = pd.read_stata(url).assign(apply_codes=lambda t: t["apply"].cat.codes)

target = "apply_codes"
features = [c for c in df.columns if c not in {target, "apply"}]

X, y = df[features].to_numpy(), df[target].to_numpy()
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

clf = OrdinalClassifier(LogisticRegression(), n_jobs=-1)
_ = clf.fit(X_train, y_train)
clf.predict_proba(X_test)
```
@@ -74,10 +94,11 @@ class OrdinalClassifier(MultiOutputMixin, ClassifierMixin, MetaEstimatorMixin, B
"""

def __init__(self, estimator, *, n_jobs=None, use_calibration=False, calibration_kwargs=None):
self.estimator = estimator
self.n_jobs = n_jobs
self.use_calibration = use_calibration
self.calibration_kwargs = calibration_kwargs

def fit(self, X, y):
"""Fit the `OrdinalClassifier` model on training data `X` and `y` by fitting its underlying estimators on
@@ -112,6 +133,9 @@ def fit(self, X, y):
self.classes_ = np.sort(np.unique(y))
self.n_features_in_ = X.shape[1]

if self.n_classes_ < 2:
raise ValueError("Classifier can't train when only one class is present.")

if self.n_jobs is None or self.n_jobs == 1:
self.estimators_ = {y_label: self._fit_binary_estimator(X, y, y_label) for y_label in self.classes_[:-1]}
else:
@@ -123,6 +147,7 @@ def fit(self, X, y):
),
)
)

return self

def predict_proba(self, X):
@@ -146,6 +171,7 @@ def predict_proba(self, X):
If `X` has a different number of features than the one seen during `fit`.
"""
check_is_fitted(self, ["estimators_", "classes_"])
X = check_array(X, ensure_2d=True, estimator=self)

if X.shape[1] != self.n_features_in_:
raise ValueError(f"X has {X.shape[1]} features, expected {self.n_features_in_} features.")
@@ -158,6 +184,7 @@ def predict_proba(self, X):

def predict(self, X):
"""Predict class labels for samples in `X` as the class with the highest probability."""
check_is_fitted(self, ["estimators_", "classes_"])
return self.classes_[np.argmax(self.predict_proba(X), axis=1)]

def _fit_binary_estimator(self, X, y, y_label):
@@ -178,8 +205,29 @@ def _fit_binary_estimator(self, X, y, y_label):
The fitted binary classifier.
"""
y_bin = (y <= y_label).astype(int)
if self.use_calibration:
return CalibratedClassifierCV(estimator=clone(self.estimator), **(self.calibration_kwargs or {})).fit(X, y_bin)
else:
return clone(self.estimator).fit(X, y_bin)

def score(self, X, y):
"""Returns the accuracy score on the given test data and labels.
Parameters
----------
X : array-like of shape (n_samples, n_features )
The training data.
y : array-like of shape (n_samples,)
The target values.
Returns
-------
score : float
Accuracy score of self.predict(X) wrt. y.
"""
return accuracy_score(y, self.predict(X))

@property
def n_classes_(self):
"""Number of classes."""
return len(self.classes_)
2 changes: 1 addition & 1 deletion tests/test_meta/test_grouped_transformer.py
@@ -187,7 +187,7 @@ def test_multiple_grouping_columns(dataset_with_multiple_grouping, scaling_range

assert np.allclose(df_with_groups.groupby(["A", "B"]).min(), scaling_range[0])

# If a group has a single element, it defaults to min, so check whether all maxes are one of the bounds
maxes = df_with_groups.groupby(["A", "B"]).max()
assert np.all(
np.isclose(maxes, scaling_range[1]) | np.isclose(maxes, scaling_range[0])