Skip to content

Commit

Permalink
undo more changes that broke tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ahuber21 committed Dec 17, 2024
1 parent 9964c5a commit 84afb62
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 61 deletions.
36 changes: 8 additions & 28 deletions onedal/ensemble/forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,20 +330,10 @@ def _fit(self, X, y, sample_weight, module, queue):

if self.oob_score:
if isinstance(self, ClassifierMixin):
# self.oob_score_ = from_table(train_result.oob_err_accuracy).item()
self.oob_score_ = from_table(
train_result.oob_err_accuracy,
sua_iface=sua_iface,
sycl_queue=queue,
xp=xp,
)[0]

self.oob_score_ = from_table(train_result.oob_err_accuracy).item()
self.oob_decision_function_ = from_table(
train_result.oob_err_decision_function,
sua_iface=sua_iface,
sycl_queue=queue,
xp=xp,
)[0]
train_result.oob_err_decision_function
)
if xp.any(self.oob_decision_function_ == 0):
warnings.warn(
"Some inputs do not have OOB scores. This probably means "
Expand All @@ -352,21 +342,11 @@ def _fit(self, X, y, sample_weight, module, queue):
UserWarning,
)
else:
# self.oob_score_ = from_table(train_result.oob_err_r2).item()
self.oob_score_ = from_table(
train_result.oob_err_r2, sua_iface=sua_iface, sycl_queue=queue, xp=xp
)[0]
# self.oob_prediction_ = from_table(train_result.oob_err_prediction).reshape(-1)
self.oob_score_ = xp.reshape(
from_table(
train_result.oob_err_r2,
sua_iface=sua_iface,
sycl_queue=queue,
xp=xp,
),
-1,
)
if xp.any(self.oob_prediction_ == 0):
self.oob_score_ = from_table(train_result.oob_err_r2).item()
self.oob_prediction_ = from_table(
train_result.oob_err_prediction
).reshape(-1)
if np.any(self.oob_prediction_ == 0):
warnings.warn(
"Some inputs do not have OOB scores. This probably means "
"too few trees were used to compute any reliable OOB "
Expand Down
63 changes: 30 additions & 33 deletions sklearnex/ensemble/_forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -812,42 +812,39 @@ def _onedal_gpu_supported(self, method_name, *data):
return patching_status

def _onedal_predict(self, X, queue=None):
xp, _ = get_namespace(X)
use_raw_input = get_config()["use_raw_input"]
if not use_raw_input:
if sklearn_check_version("1.0"):
X = validate_data(
self,
if sklearn_check_version("1.0"):
X = validate_data(
self,
X,
dtype=[np.float64, np.float32],
force_all_finite=False,
reset=False,
ensure_2d=True,
)
else:
if not get_config()["use_raw_input"]:
X = check_array(

Check warning on line 826 in sklearnex/ensemble/_forest.py

View check run for this annotation

Codecov / codecov/patch

sklearnex/ensemble/_forest.py#L825-L826

Added lines #L825 - L826 were not covered by tests
X,
dtype=[np.float64, np.float32],
force_all_finite=False,
reset=False,
ensure_2d=True,
)
# sklearn version < 1.0 is not supported
# else:
# X = check_array(
# X,
# dtype=[np.float64, np.float32],
# force_all_finite=False,
# ) # Warning, order of dtype matters
# if hasattr(self, "n_features_in_"):
# try:
# num_features = _num_features(X)
# except TypeError:
# num_features = _num_samples(X)
# if num_features != self.n_features_in_:
# raise ValueError(
# (
# f"X has {num_features} features, "
# f"but {self.__class__.__name__} is expecting "
# f"{self.n_features_in_} features as input"
# )
# )
# self._check_n_features(X, reset=False)
res = xp.reshape(self._onedal_estimator.predict(X, queue=queue), -1)
res = xp.astype(res, xp.int64)
return xp.take(self.classes_, res)
) # Warning, order of dtype matters
if hasattr(self, "n_features_in_"):
try:
num_features = _num_features(X)
except TypeError:
num_features = _num_samples(X)
if num_features != self.n_features_in_:
raise ValueError(
(
f"X has {num_features} features, "
f"but {self.__class__.__name__} is expecting "
f"{self.n_features_in_} features as input"
)
)
self._check_n_features(X, reset=False)

res = self._onedal_estimator.predict(X, queue=queue)
return np.take(self.classes_, res.ravel().astype(np.int64, casting="unsafe"))

def _onedal_predict_proba(self, X, queue=None):
xp, _ = get_namespace(X)
Expand Down

0 comments on commit 84afb62

Please sign in to comment.