feat: support pre-fitted estimators
lsorber authored Apr 21, 2024
1 parent 0aa00df commit facb436
Showing 2 changed files with 20 additions and 12 deletions.
26 changes: 17 additions & 9 deletions src/conformal_tights/_conformal_coherent_quantile_regressor.py
@@ -5,6 +5,7 @@
 import numpy as np
 import numpy.typing as npt
 from sklearn.base import BaseEstimator, MetaEstimatorMixin, RegressorMixin, clone
+from sklearn.exceptions import NotFittedError
 from sklearn.model_selection import train_test_split
 from sklearn.utils.validation import (
     check_array,
@@ -134,15 +135,22 @@ def fit(
         self.sample_weight_calib_l1_, self.sample_weight_calib_l2_ = (
             sample_weights_calib[:2] if sample_weight is not None else (None, None)  # type: ignore[has-type]
         )
-        # Fit the given estimator on the training data.
-        self.estimator_ = (
-            clone(self.estimator)
-            if self.estimator != "auto"
-            else XGBRegressor(objective="reg:absoluteerror")
-        )
-        if isinstance(self.estimator_, XGBRegressor):
-            self.estimator_.set_params(enable_categorical=True, random_state=self.random_state)
-        self.estimator_.fit(X_train, y_train, sample_weight=sample_weight_train)
+        # Check if the estimator was pre-fitted.
+        try:
+            check_is_fitted(self.estimator)
+        except (NotFittedError, TypeError):
+            # Fit the given estimator on the training data.
+            self.estimator_ = (
+                clone(self.estimator)
+                if self.estimator != "auto"
+                else XGBRegressor(objective="reg:absoluteerror")
+            )
+            if isinstance(self.estimator_, XGBRegressor):
+                self.estimator_.set_params(enable_categorical=True, random_state=self.random_state)
+            self.estimator_.fit(X_train, y_train, sample_weight=sample_weight_train)
+        else:
+            # Use the pre-fitted estimator.
+            self.estimator_ = self.estimator
         # Fit a nonconformity estimator on the training data with XGBRegressor's vector quantile
         # regression. We fit a minimal number of quantiles to reduce the computational cost, but
         # also to reduce the risk of overfitting in the coherent quantile regressor that is applied
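
The new branch relies on scikit-learn's check_is_fitted: an unfitted estimator raises NotFittedError, while a non-estimator input such as the string "auto" raises TypeError, so catching both exceptions routes every not-yet-fitted input to the original fitting path. Below is a minimal usage sketch of the behavior this commit enables; the ConformalCoherentQuantileRegressor import path and the predict_quantiles call are assumptions based on the package's public API rather than part of this commit, and the data is synthetic.

import numpy as np
from xgboost import XGBRegressor

from conformal_tights import ConformalCoherentQuantileRegressor

# Synthetic toy data, for illustration only.
rng = np.random.default_rng(42)
X = rng.normal(size=(500, 4))
y = X @ rng.normal(size=4) + 0.1 * rng.normal(size=500)

# Pre-fit an estimator outside of the conformal regressor.
estimator = XGBRegressor(objective="reg:absoluteerror")
estimator.fit(X, y)

# As of this commit, fit() detects via check_is_fitted that the estimator
# is already fitted and reuses it instead of cloning and refitting it.
conformal = ConformalCoherentQuantileRegressor(estimator=estimator)
conformal.fit(X, y)
y_quantiles = conformal.predict_quantiles(X, quantiles=(0.025, 0.5, 0.975))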
6 changes: 3 additions & 3 deletions tests/conftest.py
@@ -1,6 +1,6 @@
 """Test fixtures."""
 
-from typing import TypeAlias
+from typing import Literal, TypeAlias
 
 import pandas as pd
 import pytest
@@ -39,12 +39,12 @@ def dataset(request: SubRequest) -> Dataset:
 
 @pytest.fixture(
     params=[
-        pytest.param(XGBRegressor(objective="reg:absoluteerror"), id="model:XGBRegressor-L1"),
+        pytest.param("auto", id="model:auto"),
         pytest.param(XGBRegressor(objective="reg:squarederror"), id="model:XGBRegressor-L2"),
         pytest.param(LGBMRegressor(objective="regression_l1"), id="model:LGBMRegressor-L1"),
        pytest.param(LGBMRegressor(objective="regression_l2"), id="model:LGBMRegressor-L2"),
     ]
 )
-def regressor(request: SubRequest) -> BaseEstimator:
+def regressor(request: SubRequest) -> BaseEstimator | Literal["auto"]:
     """Return a scikit-learn regressor."""
     return request.param
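
With "auto" replacing the explicit L1 XGBRegressor in the parametrized fixture, every test that consumes regressor now also runs against the default estimator path. A hypothetical consumer might look like the sketch below; the test name, the synthetic data, and the ConformalCoherentQuantileRegressor usage are assumptions for illustration, not part of this diff.

from typing import Literal

import numpy as np
from sklearn.base import BaseEstimator

from conformal_tights import ConformalCoherentQuantileRegressor


def test_conformal_fit(regressor: BaseEstimator | Literal["auto"]) -> None:
    """Hypothetical test: each parametrized regressor, including "auto", can be wrapped."""
    rng = np.random.default_rng(0)
    X = rng.normal(size=(200, 3))
    y = X.sum(axis=1)
    model = ConformalCoherentQuantileRegressor(estimator=regressor)
    model.fit(X, y)
    assert model.predict(X).shape == y.shape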
