Skip to content

Commit

Permalink
feat: add Conformal Coherent Quantile Regression
Browse files Browse the repository at this point in the history
  • Loading branch information
lsorber committed Mar 15, 2024
1 parent 4eb2d1b commit 71a8219
Show file tree
Hide file tree
Showing 10 changed files with 1,512 additions and 177 deletions.
117 changes: 115 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,124 @@

# 👖 Conformal Tights

A scikit-learn [meta-estimator](https://scikit-learn.org/stable/glossary.html#term-meta-estimator) for computing tight [conformal predictions](https://en.wikipedia.org/wiki/Conformal_prediction).
A [scikit-learn meta-estimator](https://scikit-learn.org/stable/glossary.html#term-meta-estimator) that adds [conformal prediction](https://en.wikipedia.org/wiki/Conformal_prediction) of coherent [quantiles](https://en.wikipedia.org/wiki/Quantile) and [intervals](https://en.wikipedia.org/wiki/Prediction_interval) to any [scikit-learn regressor](https://scikit-learn.org/stable/glossary.html#term-regressor). Features:

1. 🍬 *Meta-estimator*: add prediction of quantiles and intervals to any scikit-learn regressor
2. 🌡️ *Conformally calibrated:* accurate quantiles and intervals with reliable [coverage](https://en.wikipedia.org/wiki/Coverage_probability)
3. 🚦 *Coherent quantiles:* quantiles increase monotonically instead of [crossing](https://github.com/dmlc/xgboost/issues/9848) [each other](https://github.com/microsoft/LightGBM/issues/3447)
4. 👖 *Tight quantiles:* selects the lowest [dispersion](https://en.wikipedia.org/wiki/Statistical_dispersion) that provides the desired coverage
5. 🎁 *Data efficient:* requires only a small number of calibration examples to fit
6. 🐼 *Pandas support:* optionally predict on DataFrames and receive DataFrame output

## Using

To add and install this package as a dependency of your project, run `poetry add conformal-tights`.
### Installing

First, install this package with:

```sh
pip install conformal-tights
```

### Predicting quantiles

Conformal Tights exposes a meta-estimator called `ConformalCoherentQuantileRegressor` that you can use to wrap any scikit-learn regressor, after which you can use `predict_quantiles` predict conformally calibrated quantiles. Example usage:

```python
from conformal_tights import ConformalCoherentQuantileRegressor
from sklearn.datasets import fetch_openml
from sklearn.model_selection import train_test_split
from xgboost import XGBRegressor

# Fetch dataset and split in train and test
X, y = fetch_openml("ames_housing", version=1, return_X_y=True, as_frame=True, parser="auto")
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.15, random_state=42)

# Create a regressor, wrap it, and fit on the train set
my_regressor = XGBRegressor(objective="reg:absoluteerror")
conformal_predictor = ConformalCoherentQuantileRegressor(estimator=my_regressor)
conformal_predictor.fit(X_train, y_train)

# Predict with the wrapped regressor
ŷ_test = conformal_predictor.predict(X_test)

# Predict quantiles with the conformal wrapper
ŷ_test_quantiles = conformal_predictor.predict_quantiles(X_test, quantiles=(0.025, 0.05, 0.1, 0.9, 0.95, 0.975))
```

When the input data is a pandas DataFrame, the output is also a pandas DataFrame. For example, printing the head of `ŷ_test_quantiles` yields:

| house_id | 0.025 | 0.05 | 0.1 | 0.9 | 0.95 | 0.975 |
|-----------:|--------:|-------:|-------:|-------:|-------:|--------:|
| 1357 | 121557 | 130272 | 139913 | 189399 | 211177 | 237309 |
| 2367 | 86005 | 92617 | 98591 | 130236 | 145686 | 164766 |
| 2822 | 116523 | 121711 | 134993 | 175583 | 194964 | 216891 |
| 2126 | 105712 | 113784 | 122145 | 164330 | 183352 | 206224 |
| 1544 | 85920 | 92311 | 99130 | 133228 | 148895 | 167969 |

Let's visualize the predicted quantiles on the test set:

<img src="" width="512">

<details>
<summary>Expand to see the code that generated the graph above</summary>

```python
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
%config InlineBackend.figure_format = "retina"
plt.rcParams["font.size"] = 8
idx = (-ŷ_test.sample(50, random_state=42)).sort_values().index
y_ticks = list(range(1, len(idx) + 1))
plt.figure(figsize=(4, 5))
for j in range(3):
end = ŷ_test_quantiles.shape[1] - 1 - j
coverage = round(100 * (ŷ_test_quantiles.columns[end] - ŷ_test_quantiles.columns[j]))
plt.barh(
y_ticks,
ŷ_test_quantiles.loc[idx].iloc[:, end] - ŷ_test_quantiles.loc[idx].iloc[:, j],
left=ŷ_test_quantiles.loc[idx].iloc[:, j],
label=f"{coverage}% Prediction interval",
color=["#b3d9ff", "#86bfff", "#4da6ff"][j],
)
plt.plot(y_test.loc[idx], y_ticks, "s", markersize=3, markerfacecolor="none", markeredgecolor="#e74c3c", label="Actual value")
plt.plot(ŷ_test.loc[idx], y_ticks, "s", color="blue", markersize=0.6, label="Predicted value")
plt.xlabel("House price")
plt.ylabel("Test house index")
plt.yticks(y_ticks, y_ticks)
plt.tick_params(axis="y", labelsize=6)
plt.grid(axis="x", color="lightsteelblue", linestyle=":", linewidth=0.5)
plt.gca().xaxis.set_major_formatter(ticker.StrMethodFormatter("${x:,.0f}"))
plt.gca().spines["top"].set_visible(False)
plt.gca().spines["right"].set_visible(False)
plt.legend()
plt.tight_layout()
plt.show()
```
</details>

### Predicting intervals

In addition to quantile prediction, you can use `predict_interval` to predict conformally calibrated prediction intervals. Compared to quantiles, these focus on reliable coverage over quantile accuracy. Example usage:

```python
# Predict an interval for each example with the conformal wrapper
ŷ_test_interval = conformal_predictor.predict_interval(X_test, coverage=0.95)

# Measure the coverage of the prediction intervals on the test set
coverage = ((ŷ_test_interval.iloc[:, 0] <= y_test) & (y_test <= ŷ_test_interval.iloc[:, 1])).mean()
print(coverage) # 96.6%
```

When the input data is a pandas DataFrame, the output is also a pandas DataFrame. For example, printing the head of `ŷ_test_interval` yields:

| house_id | 0.025 | 0.975 |
|-----------:|--------:|--------:|
| 1357 | 108489 | 238396 |
| 2367 | 76043 | 165189 |
| 2822 | 101319 | 220247 |
| 2126 | 94238 | 207501 |
| 1544 | 75976 | 168741 |

## Contributing

Expand Down
Loading

0 comments on commit 71a8219

Please sign in to comment.