Skip to content

Commit

Permalink
ENH: initiate jacknife gp
Browse files Browse the repository at this point in the history
  • Loading branch information
vincentblot28 committed Jul 28, 2023
1 parent 97a8cba commit 878e56e
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 10 deletions.
39 changes: 32 additions & 7 deletions mapie/estimator/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import numpy as np
from joblib import Parallel, delayed
from sklearn.base import RegressorMixin, clone
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.model_selection import BaseCrossValidator, ShuffleSplit
from sklearn.utils import _safe_indexing
from sklearn.utils.validation import (_num_samples, check_is_fitted)
Expand Down Expand Up @@ -246,10 +247,16 @@ def _predict_oof_estimator(
"""
X_val = _safe_indexing(X, val_index)
if _num_samples(X_val) > 0:
y_pred = estimator.predict(X_val)
if isinstance(estimator, GaussianProcessRegressor):
y_pred, y_std = estimator.predict(X_val, return_std=True)
else:
y_pred = estimator.predict(X_val)
else:
y_pred = np.array([])
return y_pred, val_index
if isinstance(estimator, GaussianProcessRegressor):
return [y_pred, y_std], val_index
else:
return y_pred, val_index

def _aggregate_with_mask(
self,
Expand Down Expand Up @@ -360,16 +367,34 @@ def predict_calib(self, X: ArrayLike) -> NDArray:
fill_value=np.nan,
dtype=float,
)
for i, ind in enumerate(indices):
pred_matrix[ind, i] = np.array(
predictions[i], dtype=float
if isinstance(self.estimator, GaussianProcessRegressor):
std_matrix = np.full(
shape=(n_samples, cv.get_n_splits(X)),
fill_value=np.nan,
dtype=float,
)
for i, ind in enumerate(indices):
if not isinstance(self.estimator, GaussianProcessRegressor):
pred_matrix[ind, i] = np.array(
predictions[i], dtype=float
)
else:
pred_matrix[ind, i] = np.array(
predictions[i][0], dtype=float
)
std_matrix[ind, i] = np.array(
predictions[i][1], dtype=float
)
self.k_[ind, i] = 1
check_nan_in_aposteriori_prediction(pred_matrix)

y_pred = aggregate_all(self.agg_function, pred_matrix)

return y_pred
if isinstance(self.estimator, GaussianProcessRegressor):
y_std = aggregate_all(self.agg_function, std_matrix)
if isinstance(self.estimator, GaussianProcessRegressor):
return y_pred, y_std
else:
return y_pred

def fit(
self,
Expand Down
10 changes: 7 additions & 3 deletions mapie/regression/regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import numpy as np
from sklearn.base import BaseEstimator, RegressorMixin
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import BaseCrossValidator
from sklearn.pipeline import Pipeline
Expand Down Expand Up @@ -507,12 +508,15 @@ def fit(
)
# Fit the prediction function
self.estimator_ = self.estimator_.fit(X, y, sample_weight)
y_pred = self.estimator_.predict_calib(X)

if isinstance(estimator, GaussianProcessRegressor):
y_pred, y_std = self.estimator_.predict_calib(X)
else:
y_pred = self.estimator_.predict_calib(X)
y_std = None
# Compute the conformity scores (manage jk-ab case)
self.conformity_scores_ = \
self.conformity_score_function_.get_conformity_scores(
X, y, y_pred
X, y, y_pred, y_std
)

return self
Expand Down
15 changes: 15 additions & 0 deletions mapie/regression/test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import numpy as np
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.model_selection import train_test_split
from mapie.regression import MapieRegressor

def g(x):
return (3 * x * np.sin(x) - 2 * x * np.cos(x) + ( x ** 3) / 40 - .5 * x ** 2 - 10 * x)

X = np.linspace(-40, 60, 100).reshape(-1, 1)
y = g(X)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=.2)
gp = GaussianProcessRegressor().fit(X_train, y_train)

mapie = MapieRegressor(estimator=gp, method="plus", cv=-1)
mapie.fit(X_train, y_train)

0 comments on commit 878e56e

Please sign in to comment.