diff --git a/onedal/basic_statistics/basic_statistics.cpp b/onedal/basic_statistics/basic_statistics.cpp index 3f10fd0893..966850b69c 100644 --- a/onedal/basic_statistics/basic_statistics.cpp +++ b/onedal/basic_statistics/basic_statistics.cpp @@ -224,6 +224,7 @@ ONEDAL_PY_INIT_MODULE(basic_statistics) { #ifdef ONEDAL_DATA_PARALLEL_SPMD ONEDAL_PY_INSTANTIATE(init_compute_ops, sub, policy_spmd, task::compute); + ONEDAL_PY_INSTANTIATE(init_finalize_compute_ops, sub, policy_spmd, task::compute); #else // ONEDAL_DATA_PARALLEL_SPMD ONEDAL_PY_INSTANTIATE(init_compute_ops, sub, policy_list, task::compute); ONEDAL_PY_INSTANTIATE(init_partial_compute_ops, sub, policy_list, task::compute); diff --git a/onedal/basic_statistics/incremental_basic_statistics.py b/onedal/basic_statistics/incremental_basic_statistics.py index cbc7019321..b3c304a2af 100644 --- a/onedal/basic_statistics/incremental_basic_statistics.py +++ b/onedal/basic_statistics/incremental_basic_statistics.py @@ -66,12 +66,12 @@ class IncrementalBasicStatistics(BaseBasicStatistics): def __init__(self, result_options="all"): super().__init__(result_options, algorithm="by_default") - module = self._get_backend("basic_statistics") - self._partial_result = module.partial_compute_result() + self._reset() def _reset(self): - module = self._get_backend("basic_statistics") - self._partial_result = module.partial_train_result() + self._partial_result = self._get_backend( + "basic_statistics", None, "partial_compute_result" + ) def partial_fit(self, X, weights=None, queue=None): """ @@ -92,19 +92,20 @@ def partial_fit(self, X, weights=None, queue=None): self : object Returns the instance itself. """ - if not hasattr(self, "_policy"): - self._policy = self._get_policy(queue, X) - - X, weights = _convert_to_supported(self._policy, X, weights) + self._queue = queue + policy = self._get_policy(queue, X) + X, weights = _convert_to_supported(policy, X, weights) if not hasattr(self, "_onedal_params"): dtype = get_dtype(X) - self._onedal_params = self._get_onedal_params(dtype) + self._onedal_params = self._get_onedal_params(False, dtype=dtype) X_table, weights_table = to_table(X, weights) - module = self._get_backend("basic_statistics") - self._partial_result = module.partial_compute( - self._policy, + self._partial_result = self._get_backend( + "basic_statistics", + None, + "partial_compute", + policy, self._onedal_params, self._partial_result, X_table, @@ -119,16 +120,26 @@ def finalize_fit(self, queue=None): Parameters ---------- queue : dpctl.SyclQueue - Not used here, added for API conformance + If not None, use this queue for computations. Returns ------- self : object Returns the instance itself. """ - module = self._get_backend("basic_statistics") - result = module.finalize_compute( - self._policy, self._onedal_params, self._partial_result + + if queue is not None: + policy = self._get_policy(queue) + else: + policy = self._get_policy(self._queue) + + result = self._get_backend( + "basic_statistics", + None, + "finalize_compute", + policy, + self._onedal_params, + self._partial_result, ) options = self._get_result_options(self.options).split("|") for opt in options: diff --git a/onedal/spmd/basic_statistics/__init__.py b/onedal/spmd/basic_statistics/__init__.py index 75a6fdf5fa..c756a35f6e 100644 --- a/onedal/spmd/basic_statistics/__init__.py +++ b/onedal/spmd/basic_statistics/__init__.py @@ -15,5 +15,6 @@ # ============================================================================== from .basic_statistics import BasicStatistics +from .incremental_basic_statistics import IncrementalBasicStatistics -__all__ = ["BasicStatistics"] +__all__ = ["BasicStatistics", "IncrementalBasicStatistics"] diff --git a/onedal/spmd/basic_statistics/basic_statistics.py b/onedal/spmd/basic_statistics/basic_statistics.py index 8103c570b5..943da2bf92 100644 --- a/onedal/spmd/basic_statistics/basic_statistics.py +++ b/onedal/spmd/basic_statistics/basic_statistics.py @@ -14,8 +14,6 @@ # limitations under the License. # ============================================================================== -import warnings - from onedal.basic_statistics import BasicStatistics as BasicStatistics_Batch from ..._device_offload import support_usm_ndarray diff --git a/onedal/spmd/basic_statistics/incremental_basic_statistics.py b/onedal/spmd/basic_statistics/incremental_basic_statistics.py new file mode 100644 index 0000000000..a0bd62868a --- /dev/null +++ b/onedal/spmd/basic_statistics/incremental_basic_statistics.py @@ -0,0 +1,69 @@ +# ============================================================================== +# Copyright 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from daal4py.sklearn._utils import get_dtype + +from ...basic_statistics import ( + IncrementalBasicStatistics as base_IncrementalBasicStatistics, +) +from ...datatypes import _convert_to_supported, to_table +from .._base import BaseEstimatorSPMD + + +class IncrementalBasicStatistics(BaseEstimatorSPMD, base_IncrementalBasicStatistics): + def _reset(self): + self._partial_result = super(base_IncrementalBasicStatistics, self)._get_backend( + "basic_statistics", None, "partial_compute_result" + ) + + def partial_fit(self, X, weights=None, queue=None): + """ + Computes partial data for basic statistics + from data batch X and saves it to `_partial_result`. + + Parameters + ---------- + X : array-like of shape (n_samples, n_features) + Training data batch, where `n_samples` is the number of samples + in the batch, and `n_features` is the number of features. + + queue : dpctl.SyclQueue + If not None, use this queue for computations. + + Returns + ------- + self : object + Returns the instance itself. + """ + self._queue = queue + policy = super(base_IncrementalBasicStatistics, self)._get_policy(queue, X) + X, weights = _convert_to_supported(policy, X, weights) + + if not hasattr(self, "_onedal_params"): + dtype = get_dtype(X) + self._onedal_params = self._get_onedal_params(False, dtype=dtype) + + X_table, weights_table = to_table(X, weights) + self._partial_result = super(base_IncrementalBasicStatistics, self)._get_backend( + "basic_statistics", + None, + "partial_compute", + policy, + self._onedal_params, + self._partial_result, + X_table, + weights_table, + ) diff --git a/sklearnex/basic_statistics/incremental_basic_statistics.py b/sklearnex/basic_statistics/incremental_basic_statistics.py index 805bc4d716..2ffa143421 100644 --- a/sklearnex/basic_statistics/incremental_basic_statistics.py +++ b/sklearnex/basic_statistics/incremental_basic_statistics.py @@ -120,7 +120,7 @@ def __init__(self, result_options="all", batch_size=None): def _onedal_supported(self, method_name, *data): patching_status = PatchingConditionsChain( - f"sklearn.covariance.{self.__class__.__name__}.{method_name}" + f"sklearn.basic_statistics.{self.__class__.__name__}.{method_name}" ) return patching_status @@ -135,9 +135,9 @@ def _get_onedal_result_options(self, options): assert isinstance(onedal_options, str) return options - def _onedal_finalize_fit(self): + def _onedal_finalize_fit(self, queue=None): assert hasattr(self, "_onedal_estimator") - self._onedal_estimator.finalize_fit() + self._onedal_estimator.finalize_fit(queue=queue) self._need_to_finalize = False def _onedal_partial_fit(self, X, sample_weight=None, queue=None): @@ -171,7 +171,7 @@ def _onedal_partial_fit(self, X, sample_weight=None, queue=None): self._onedal_estimator = self._onedal_incremental_basic_statistics( **onedal_params ) - self._onedal_estimator.partial_fit(X, sample_weight, queue) + self._onedal_estimator.partial_fit(X, weights=sample_weight, queue=queue) self._need_to_finalize = True def _onedal_fit(self, X, sample_weight=None, queue=None): @@ -203,7 +203,7 @@ def _onedal_fit(self, X, sample_weight=None, queue=None): self.n_features_in_ = X.shape[1] - self._onedal_finalize_fit() + self._onedal_finalize_fit(queue=queue) return self diff --git a/sklearnex/spmd/basic_statistics/__init__.py b/sklearnex/spmd/basic_statistics/__init__.py index 75a6fdf5fa..c756a35f6e 100644 --- a/sklearnex/spmd/basic_statistics/__init__.py +++ b/sklearnex/spmd/basic_statistics/__init__.py @@ -15,5 +15,6 @@ # ============================================================================== from .basic_statistics import BasicStatistics +from .incremental_basic_statistics import IncrementalBasicStatistics -__all__ = ["BasicStatistics"] +__all__ = ["BasicStatistics", "IncrementalBasicStatistics"] diff --git a/sklearnex/spmd/basic_statistics/incremental_basic_statistics.py b/sklearnex/spmd/basic_statistics/incremental_basic_statistics.py new file mode 100644 index 0000000000..e161cdc173 --- /dev/null +++ b/sklearnex/spmd/basic_statistics/incremental_basic_statistics.py @@ -0,0 +1,30 @@ +# ============================================================================== +# Copyright 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + + +from onedal.spmd.basic_statistics import ( + IncrementalBasicStatistics as onedalSPMD_IncrementalBasicStatistics, +) + +from ...basic_statistics import ( + IncrementalBasicStatistics as base_IncrementalBasicStatistics, +) + + +class IncrementalBasicStatistics(base_IncrementalBasicStatistics): + _onedal_incremental_basic_statistics = staticmethod( + onedalSPMD_IncrementalBasicStatistics + ) diff --git a/sklearnex/spmd/basic_statistics/tests/test_incremental_basic_statistics_spmd.py b/sklearnex/spmd/basic_statistics/tests/test_incremental_basic_statistics_spmd.py new file mode 100644 index 0000000000..63060e4e9b --- /dev/null +++ b/sklearnex/spmd/basic_statistics/tests/test_incremental_basic_statistics_spmd.py @@ -0,0 +1,307 @@ +# ============================================================================== +# Copyright 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import numpy as np +import pytest +from numpy.testing import assert_allclose + +from onedal.basic_statistics.tests.test_basic_statistics import options_and_tests +from onedal.tests.utils._dataframes_support import ( + _convert_to_dataframe, + get_dataframes_and_queues, +) +from sklearnex.tests._utils_spmd import ( + _generate_statistic_data, + _get_local_tensor, + _mpi_libs_and_gpu_available, +) + + +@pytest.mark.skipif( + not _mpi_libs_and_gpu_available, + reason="GPU device and MPI libs required for test", +) +@pytest.mark.parametrize( + "dataframe,queue", + get_dataframes_and_queues(dataframe_filter_="dpnp,dpctl", device_filter_="gpu"), +) +@pytest.mark.parametrize("weighted", [True, False]) +@pytest.mark.parametrize("dtype", [np.float32, np.float64]) +@pytest.mark.mpi +def test_incremental_basic_statistics_fit_spmd_gold(dataframe, queue, weighted, dtype): + # Import spmd and batch algo + from sklearnex.basic_statistics import IncrementalBasicStatistics + from sklearnex.spmd.basic_statistics import ( + IncrementalBasicStatistics as IncrementalBasicStatistics_SPMD, + ) + + # Create gold data and process into dpt + data = np.array( + [ + [0.0, 0.0, 0.0], + [0.0, 1.0, 2.0], + [0.0, 2.0, 4.0], + [0.0, 3.0, 8.0], + [0.0, 4.0, 16.0], + [0.0, 5.0, 32.0], + [0.0, 6.0, 64.0], + [0.0, 7.0, 128.0], + ], + dtype=dtype, + ) + dpt_data = _convert_to_dataframe(data, sycl_queue=queue, target_df=dataframe) + + local_dpt_data = _convert_to_dataframe( + _get_local_tensor(data), sycl_queue=queue, target_df=dataframe + ) + + if weighted: + # Create weights array containing the weight for each sample in the data + weights = np.array([0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0], dtype=dtype) + dpt_weights = _convert_to_dataframe( + weights, sycl_queue=queue, target_df=dataframe + ) + local_dpt_weights = _convert_to_dataframe( + _get_local_tensor(weights), sycl_queue=queue, target_df=dataframe + ) + + # ensure results of batch algo match spmd + + incbs_spmd = IncrementalBasicStatistics_SPMD().fit( + local_dpt_data, sample_weight=local_dpt_weights if weighted else None + ) + incbs = IncrementalBasicStatistics().fit( + dpt_data, sample_weight=dpt_weights if weighted else None + ) + + for option, _, _ in options_and_tests: + assert_allclose( + getattr(incbs_spmd, option), + getattr(incbs, option), + err_msg=f"Result for {option} is incorrect", + ) + + +@pytest.mark.skipif( + not _mpi_libs_and_gpu_available, + reason="GPU device and MPI libs required for test", +) +@pytest.mark.parametrize( + "dataframe,queue", + get_dataframes_and_queues(dataframe_filter_="dpnp,dpctl", device_filter_="gpu"), +) +@pytest.mark.parametrize("num_blocks", [1, 2]) +@pytest.mark.parametrize("weighted", [True, False]) +@pytest.mark.parametrize("dtype", [np.float32, np.float64]) +@pytest.mark.mpi +def test_incremental_basic_statistics_partial_fit_spmd_gold( + dataframe, queue, num_blocks, weighted, dtype +): + # Import spmd and batch algo + from sklearnex.basic_statistics import IncrementalBasicStatistics + from sklearnex.spmd.basic_statistics import ( + IncrementalBasicStatistics as IncrementalBasicStatistics_SPMD, + ) + + # Create gold data and process into dpt + data = np.array( + [ + [0.0, 0.0, 0.0], + [0.0, 1.0, 2.0], + [0.0, 2.0, 4.0], + [0.0, 3.0, 8.0], + [0.0, 4.0, 16.0], + [0.0, 5.0, 32.0], + [0.0, 6.0, 64.0], + [0.0, 7.0, 128.0], + ], + dtype=dtype, + ) + dpt_data = _convert_to_dataframe(data, sycl_queue=queue, target_df=dataframe) + local_data = _get_local_tensor(data) + split_local_data = np.array_split(local_data, num_blocks) + + if weighted: + # Create weights array containing the weight for each sample in the data + weights = np.array([0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0], dtype=dtype) + dpt_weights = _convert_to_dataframe( + weights, sycl_queue=queue, target_df=dataframe + ) + local_weights = _get_local_tensor(weights) + split_local_weights = np.array_split(local_weights, num_blocks) + + incbs_spmd = IncrementalBasicStatistics_SPMD() + incbs = IncrementalBasicStatistics() + + for i in range(num_blocks): + local_dpt_data = _convert_to_dataframe( + split_local_data[i], sycl_queue=queue, target_df=dataframe + ) + if weighted: + local_dpt_weights = _convert_to_dataframe( + split_local_weights[i], sycl_queue=queue, target_df=dataframe + ) + incbs_spmd.partial_fit( + local_dpt_data, sample_weight=local_dpt_weights if weighted else None + ) + + incbs.fit(dpt_data, sample_weight=dpt_weights if weighted else None) + + for option, _, _ in options_and_tests: + assert_allclose( + getattr(incbs_spmd, option), + getattr(incbs, option), + err_msg=f"Result for {option} is incorrect", + ) + + +@pytest.mark.skipif( + not _mpi_libs_and_gpu_available, + reason="GPU device and MPI libs required for test", +) +@pytest.mark.parametrize( + "dataframe,queue", + get_dataframes_and_queues(dataframe_filter_="dpnp,dpctl", device_filter_="gpu"), +) +@pytest.mark.parametrize("num_blocks", [1, 2]) +@pytest.mark.parametrize("weighted", [True, False]) +@pytest.mark.parametrize("result_option", options_and_tests) +@pytest.mark.parametrize("dtype", [np.float32, np.float64]) +@pytest.mark.mpi +def test_incremental_basic_statistics_single_option_partial_fit_spmd_gold( + dataframe, queue, num_blocks, weighted, result_option, dtype +): + # Import spmd and batch algo + from sklearnex.basic_statistics import IncrementalBasicStatistics + from sklearnex.spmd.basic_statistics import ( + IncrementalBasicStatistics as IncrementalBasicStatistics_SPMD, + ) + + # Create gold data and process into dpt + data = np.array( + [ + [0.0, 0.0, 0.0], + [0.0, 1.0, 2.0], + [0.0, 2.0, 4.0], + [0.0, 3.0, 8.0], + [0.0, 4.0, 16.0], + [0.0, 5.0, 32.0], + [0.0, 6.0, 64.0], + [0.0, 7.0, 128.0], + ], + dtype=dtype, + ) + dpt_data = _convert_to_dataframe(data, sycl_queue=queue, target_df=dataframe) + local_data = _get_local_tensor(data) + split_local_data = np.array_split(local_data, num_blocks) + + if weighted: + # Create weights array containing the weight for each sample in the data + weights = np.array([0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0], dtype=dtype) + dpt_weights = _convert_to_dataframe( + weights, sycl_queue=queue, target_df=dataframe + ) + local_weights = _get_local_tensor(weights) + split_local_weights = np.array_split(local_weights, num_blocks) + + option, _, _ = result_option + incbs_spmd = IncrementalBasicStatistics_SPMD(result_options=option) + incbs = IncrementalBasicStatistics(result_options=option) + + for i in range(num_blocks): + local_dpt_data = _convert_to_dataframe( + split_local_data[i], sycl_queue=queue, target_df=dataframe + ) + if weighted: + local_dpt_weights = _convert_to_dataframe( + split_local_weights[i], sycl_queue=queue, target_df=dataframe + ) + incbs_spmd.partial_fit( + local_dpt_data, sample_weight=local_dpt_weights if weighted else None + ) + + incbs.fit(dpt_data, sample_weight=dpt_weights if weighted else None) + + assert_allclose(getattr(incbs_spmd, option), getattr(incbs, option)) + + +@pytest.mark.skipif( + not _mpi_libs_and_gpu_available, + reason="GPU device and MPI libs required for test", +) +@pytest.mark.parametrize( + "dataframe,queue", + get_dataframes_and_queues(dataframe_filter_="dpnp,dpctl", device_filter_="gpu"), +) +@pytest.mark.parametrize("num_blocks", [1, 2]) +@pytest.mark.parametrize("weighted", [True, False]) +@pytest.mark.parametrize("n_samples", [100, 10000]) +@pytest.mark.parametrize("n_features", [10, 100]) +@pytest.mark.parametrize("dtype", [np.float32, np.float64]) +@pytest.mark.mpi +def test_incremental_basic_statistics_partial_fit_spmd_synthetic( + dataframe, queue, num_blocks, weighted, n_samples, n_features, dtype +): + # Import spmd and batch algo + from sklearnex.basic_statistics import IncrementalBasicStatistics + from sklearnex.spmd.basic_statistics import ( + IncrementalBasicStatistics as IncrementalBasicStatistics_SPMD, + ) + + tol = 2e-3 if dtype == np.float32 else 1e-7 + + # Create gold data and process into dpt + data = _generate_statistic_data(n_samples, n_features, dtype=dtype) + local_data = _get_local_tensor(data) + split_local_data = np.array_split(local_data, num_blocks) + split_data = np.array_split(data, num_blocks) + + if weighted: + # Create weights array containing the weight for each sample in the data + weights = _generate_statistic_data(n_samples, dtype=dtype) + local_weights = _get_local_tensor(weights) + split_local_weights = np.array_split(local_weights, num_blocks) + split_weights = np.array_split(weights, num_blocks) + + incbs_spmd = IncrementalBasicStatistics_SPMD() + incbs = IncrementalBasicStatistics() + + for i in range(num_blocks): + local_dpt_data = _convert_to_dataframe( + split_local_data[i], sycl_queue=queue, target_df=dataframe + ) + dpt_data = _convert_to_dataframe( + split_data[i], sycl_queue=queue, target_df=dataframe + ) + if weighted: + local_dpt_weights = _convert_to_dataframe( + split_local_weights[i], sycl_queue=queue, target_df=dataframe + ) + dpt_weights = _convert_to_dataframe( + split_weights[i], sycl_queue=queue, target_df=dataframe + ) + incbs_spmd.partial_fit( + local_dpt_data, sample_weight=local_dpt_weights if weighted else None + ) + incbs.partial_fit(dpt_data, sample_weight=dpt_weights if weighted else None) + + for option, _, _ in options_and_tests: + assert_allclose( + getattr(incbs_spmd, option), + getattr(incbs, option), + atol=tol, + err_msg=f"Result for {option} is incorrect", + ) diff --git a/sklearnex/tests/_utils_spmd.py b/sklearnex/tests/_utils_spmd.py index 172db788be..4bdd4d4fd5 100644 --- a/sklearnex/tests/_utils_spmd.py +++ b/sklearnex/tests/_utils_spmd.py @@ -89,10 +89,16 @@ def _generate_classification_data( return X_train, X_test, y_train, y_test -def _generate_statistic_data(n_samples, n_features, dtype=np.float64, random_state=42): +def _generate_statistic_data( + n_samples, n_features=None, dtype=np.float64, random_state=42 +): # Generates statistical data gen = np.random.default_rng(random_state) - data = gen.uniform(low=-0.3, high=+0.7, size=(n_samples, n_features)).astype(dtype) + data = gen.uniform( + low=-0.3, + high=+0.7, + size=(n_samples, n_features) if n_features is not None else (n_samples,), + ).astype(dtype) return data