Skip to content

Commit

Permalink
Apply comments for AdaBoost and GBT estimators
Browse files Browse the repository at this point in the history
  • Loading branch information
Alexsandruss committed Dec 12, 2024
1 parent 286e8b1 commit 04a9981
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 10 deletions.
4 changes: 2 additions & 2 deletions daal4py/sklearn/ensemble/AdaBoostClassifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def fit(self, X, y):
)

# Check that X and y have correct shape
X, y = check_X_y(X, y, y_numeric=False, dtype=[np.single, np.double])
X, y = check_X_y(X, y, y_numeric=False, dtype=[np.float64, np.float32])

check_classification_targets(y)

Expand Down Expand Up @@ -157,7 +157,7 @@ def predict(self, X):
check_is_fitted(self)

# Input validation
X = validate_data(self, X, dtype=[np.single, np.double], reset=False)
X = validate_data(self, X, dtype=[np.float64, np.float32], reset=False)

# Trivial case
if self.n_classes_ == 1:
Expand Down
17 changes: 9 additions & 8 deletions daal4py/sklearn/ensemble/GBTDAAL.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,10 +134,11 @@ def _check_params(self):
def _more_tags(self):
return {"allow_nan": self.allow_nan_}

def __sklearn_tags__(self):
tags = super().__sklearn_tags__()
tags.input_tags.allow_nan = self.allow_nan_
return tags
if sklearn_check_version("1.6"):
def __sklearn_tags__(self):
tags = super().__sklearn_tags__()
tags.input_tags.allow_nan = self.allow_nan_
return tags


@control_n_jobs(decorated_methods=["fit", "predict"])
Expand All @@ -147,7 +148,7 @@ def fit(self, X, y):
self._check_params()

# Check that X and y have correct shape
X, y = check_X_y(X, y, y_numeric=False, dtype=[np.single, np.double])
X, y = check_X_y(X, y, y_numeric=False, dtype=[np.float64, np.float32])

check_classification_targets(y)

Expand Down Expand Up @@ -214,7 +215,7 @@ def _predict(
X = validate_data(
self,
X,
dtype=[np.single, np.double],
dtype=[np.float64, np.float32],
force_all_finite="allow-nan" if self.allow_nan_ else True,
reset=False,
)
Expand Down Expand Up @@ -271,7 +272,7 @@ def fit(self, X, y):
self._check_params()

# Check that X and y have correct shape
X, y = check_X_y(X, y, y_numeric=True, dtype=[np.single, np.double])
X, y = check_X_y(X, y, y_numeric=True, dtype=[np.float64, np.float32])

# Convert to 2d array
y_ = y.reshape((-1, 1))
Expand Down Expand Up @@ -318,7 +319,7 @@ def predict(self, X, pred_contribs=False, pred_interactions=False):
X = validate_data(
self,
X,
dtype=[np.single, np.double],
dtype=[np.float64, np.float32],
force_all_finite="allow-nan" if self.allow_nan_ else True,
reset=False,
)
Expand Down

0 comments on commit 04a9981

Please sign in to comment.