From 84afb62936abbbd2d5a83527cb12c8a7265ad38f Mon Sep 17 00:00:00 2001 From: Andreas Huber Date: Tue, 17 Dec 2024 05:36:22 -0800 Subject: [PATCH] undo more changes that broke tests --- onedal/ensemble/forest.py | 36 +++++--------------- sklearnex/ensemble/_forest.py | 63 +++++++++++++++++------------------ 2 files changed, 38 insertions(+), 61 deletions(-) diff --git a/onedal/ensemble/forest.py b/onedal/ensemble/forest.py index 5327dedecb..ddb214b010 100644 --- a/onedal/ensemble/forest.py +++ b/onedal/ensemble/forest.py @@ -330,20 +330,10 @@ def _fit(self, X, y, sample_weight, module, queue): if self.oob_score: if isinstance(self, ClassifierMixin): - # self.oob_score_ = from_table(train_result.oob_err_accuracy).item() - self.oob_score_ = from_table( - train_result.oob_err_accuracy, - sua_iface=sua_iface, - sycl_queue=queue, - xp=xp, - )[0] - + self.oob_score_ = from_table(train_result.oob_err_accuracy).item() self.oob_decision_function_ = from_table( - train_result.oob_err_decision_function, - sua_iface=sua_iface, - sycl_queue=queue, - xp=xp, - )[0] + train_result.oob_err_decision_function + ) if xp.any(self.oob_decision_function_ == 0): warnings.warn( "Some inputs do not have OOB scores. This probably means " @@ -352,21 +342,11 @@ def _fit(self, X, y, sample_weight, module, queue): UserWarning, ) else: - # self.oob_score_ = from_table(train_result.oob_err_r2).item() - self.oob_score_ = from_table( - train_result.oob_err_r2, sua_iface=sua_iface, sycl_queue=queue, xp=xp - )[0] - # self.oob_prediction_ = from_table(train_result.oob_err_prediction).reshape(-1) - self.oob_score_ = xp.reshape( - from_table( - train_result.oob_err_r2, - sua_iface=sua_iface, - sycl_queue=queue, - xp=xp, - ), - -1, - ) - if xp.any(self.oob_prediction_ == 0): + self.oob_score_ = from_table(train_result.oob_err_r2).item() + self.oob_prediction_ = from_table( + train_result.oob_err_prediction + ).reshape(-1) + if np.any(self.oob_prediction_ == 0): warnings.warn( "Some inputs do not have OOB scores. This probably means " "too few trees were used to compute any reliable OOB " diff --git a/sklearnex/ensemble/_forest.py b/sklearnex/ensemble/_forest.py index 455d41d9ae..ab62c219a4 100644 --- a/sklearnex/ensemble/_forest.py +++ b/sklearnex/ensemble/_forest.py @@ -812,42 +812,39 @@ def _onedal_gpu_supported(self, method_name, *data): return patching_status def _onedal_predict(self, X, queue=None): - xp, _ = get_namespace(X) - use_raw_input = get_config()["use_raw_input"] - if not use_raw_input: - if sklearn_check_version("1.0"): - X = validate_data( - self, + if sklearn_check_version("1.0"): + X = validate_data( + self, + X, + dtype=[np.float64, np.float32], + force_all_finite=False, + reset=False, + ensure_2d=True, + ) + else: + if not get_config()["use_raw_input"]: + X = check_array( X, dtype=[np.float64, np.float32], force_all_finite=False, - reset=False, - ensure_2d=True, - ) - # sklearn version < 1.0 is not supported - # else: - # X = check_array( - # X, - # dtype=[np.float64, np.float32], - # force_all_finite=False, - # ) # Warning, order of dtype matters - # if hasattr(self, "n_features_in_"): - # try: - # num_features = _num_features(X) - # except TypeError: - # num_features = _num_samples(X) - # if num_features != self.n_features_in_: - # raise ValueError( - # ( - # f"X has {num_features} features, " - # f"but {self.__class__.__name__} is expecting " - # f"{self.n_features_in_} features as input" - # ) - # ) - # self._check_n_features(X, reset=False) - res = xp.reshape(self._onedal_estimator.predict(X, queue=queue), -1) - res = xp.astype(res, xp.int64) - return xp.take(self.classes_, res) + ) # Warning, order of dtype matters + if hasattr(self, "n_features_in_"): + try: + num_features = _num_features(X) + except TypeError: + num_features = _num_samples(X) + if num_features != self.n_features_in_: + raise ValueError( + ( + f"X has {num_features} features, " + f"but {self.__class__.__name__} is expecting " + f"{self.n_features_in_} features as input" + ) + ) + self._check_n_features(X, reset=False) + + res = self._onedal_estimator.predict(X, queue=queue) + return np.take(self.classes_, res.ravel().astype(np.int64, casting="unsafe")) def _onedal_predict_proba(self, X, queue=None): xp, _ = get_namespace(X)