From 5fcc7fb514d72b51e5fc2a2d18272f6e31423cbe Mon Sep 17 00:00:00 2001 From: Ian Faust Date: Thu, 28 Nov 2024 15:14:53 +0100 Subject: [PATCH] [enhancement] add oneDAL finiteness_checker implementation to onedal (#2126) * add finiteness_checker pybind11 bindings * added finiteness checker * Update finiteness_checker.cpp * Update finiteness_checker.cpp * Update finiteness_checker.cpp * Update finiteness_checker.cpp * Update finiteness_checker.cpp * Update finiteness_checker.cpp * Rename finiteness_checker.cpp to finiteness_checker.cpp * Update finiteness_checker.cpp * add next step * follow conventions * make xtable explicit * remove comment * Update validation.py * Update __init__.py * Update validation.py * Update __init__.py * Update __init__.py * Update validation.py * Update _data_conversion.py * Update _data_conversion.py * Update policy_common.cpp * Update policy_common.cpp * Update _policy.py * Update policy_common.cpp * Rename finiteness_checker.cpp to finiteness_checker.cpp * Create finiteness_checker.py * Update validation.py * Update __init__.py * attempt at fixing circular imports again * fix isort * remove __init__ changes * last move * Update policy_common.cpp * Update policy_common.cpp * Update policy_common.cpp * Update policy_common.cpp * Update validation.py * add testing * isort * attempt to fix module error * add fptype * fix typo * Update validation.py * remove sua_ifcae from to_table * isort and black * Update test_memory_usage.py * format * Update _data_conversion.py * Update _data_conversion.py * Update test_validation.py * remove unnecessary code * make reviewer changes * make dtype check change * add sparse testing * try again * try again * try again * Update onedal/utils/tests/test_validation.py Co-authored-by: Samir Nasibli * formatting * formatting again * add _check_sample_weight * Revert "add _check_sample_weight" This reverts commit 4efad2ce70dcdf708be1290b2678727125bd4857. * Update test_validation.py * Update validation.py * make changes * Update test_validation.py --------- Co-authored-by: Samir Nasibli --- onedal/dal.cpp | 6 ++ onedal/utils/finiteness_checker.cpp | 103 ++++++++++++++++++ onedal/utils/tests/test_validation.py | 144 ++++++++++++++++++++++++++ onedal/utils/validation.py | 43 +++++++- sklearnex/utils/tests/test_finite.py | 4 +- 5 files changed, 293 insertions(+), 7 deletions(-) create mode 100644 onedal/utils/finiteness_checker.cpp create mode 100644 onedal/utils/tests/test_validation.py diff --git a/onedal/dal.cpp b/onedal/dal.cpp index 5c63b1c225..298ab39fd9 100644 --- a/onedal/dal.cpp +++ b/onedal/dal.cpp @@ -79,6 +79,9 @@ namespace oneapi::dal::python { #if defined(ONEDAL_VERSION) && ONEDAL_VERSION >= 20240001 ONEDAL_PY_INIT_MODULE(logistic_regression); #endif // defined(ONEDAL_VERSION) && ONEDAL_VERSION >= 20240001 + #if defined(ONEDAL_VERSION) && ONEDAL_VERSION >= 20240700 + ONEDAL_PY_INIT_MODULE(finiteness_checker); + #endif // defined(ONEDAL_VERSION) && ONEDAL_VERSION >= 20240700 #endif // ONEDAL_DATA_PARALLEL_SPMD #ifdef ONEDAL_DATA_PARALLEL_SPMD @@ -138,6 +141,9 @@ namespace oneapi::dal::python { #if defined(ONEDAL_VERSION) && ONEDAL_VERSION >= 20240001 init_logistic_regression(m); #endif // defined(ONEDAL_VERSION) && ONEDAL_VERSION >= 20240001 + #if defined(ONEDAL_VERSION) && ONEDAL_VERSION >= 20240700 + init_finiteness_checker(m); + #endif // defined(ONEDAL_VERSION) && ONEDAL_VERSION >= 20240700 } #endif // ONEDAL_DATA_PARALLEL_SPMD diff --git a/onedal/utils/finiteness_checker.cpp b/onedal/utils/finiteness_checker.cpp new file mode 100644 index 0000000000..2b8d84bd6f --- /dev/null +++ b/onedal/utils/finiteness_checker.cpp @@ -0,0 +1,103 @@ +/******************************************************************************* +* Copyright 2024 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +// fix error with missing headers +#if defined(ONEDAL_VERSION) && ONEDAL_VERSION >= 20250200 + #include "oneapi/dal/algo/finiteness_checker.hpp" +#else + #include "oneapi/dal/algo/finiteness_checker/compute.hpp" +#endif // defined(ONEDAL_VERSION) && ONEDAL_VERSION >= 20250200 + +#include "onedal/common.hpp" +#include "onedal/version.hpp" + +namespace py = pybind11; + +namespace oneapi::dal::python { + +template +struct method2t { + method2t(const Task& task, const Ops& ops) : ops(ops) {} + + template + auto operator()(const py::dict& params) { + using namespace finiteness_checker; + + const auto method = params["method"].cast(); + + ONEDAL_PARAM_DISPATCH_VALUE(method, "dense", ops, Float, method::dense); + ONEDAL_PARAM_DISPATCH_VALUE(method, "by_default", ops, Float, method::by_default); + ONEDAL_PARAM_DISPATCH_THROW_INVALID_VALUE(method); + } + + Ops ops; +}; + +struct params2desc { + template + auto operator()(const pybind11::dict& params) { + using namespace dal::finiteness_checker; + + auto desc = descriptor(); + desc.set_allow_NaN(params["allow_nan"].cast()); + return desc; + } +}; + +template +void init_compute_ops(py::module_& m) { + m.def("compute", + [](const Policy& policy, + const py::dict& params, + const table& data) { + using namespace finiteness_checker; + using input_t = compute_input; + + compute_ops ops(policy, input_t{ data }, params2desc{}); + return fptype2t{ method2t{ Task{}, ops } }(params); + }); +} + +template +void init_compute_result(py::module_& m) { + using namespace finiteness_checker; + using result_t = compute_result; + + py::class_(m, "compute_result") + .def(py::init()) + .DEF_ONEDAL_PY_PROPERTY(finite, result_t); +} + +ONEDAL_PY_TYPE2STR(finiteness_checker::task::compute, "compute"); + +ONEDAL_PY_DECLARE_INSTANTIATOR(init_compute_ops); +ONEDAL_PY_DECLARE_INSTANTIATOR(init_compute_result); + +ONEDAL_PY_INIT_MODULE(finiteness_checker) { + using namespace dal::detail; + using namespace finiteness_checker; + using namespace dal::finiteness_checker; + + using task_list = types; + auto sub = m.def_submodule("finiteness_checker"); + + #ifndef ONEDAL_DATA_PARALLEL_SPMD + ONEDAL_PY_INSTANTIATE(init_compute_ops, sub, policy_list, task_list); + ONEDAL_PY_INSTANTIATE(init_compute_result, sub, task_list); + #endif +} + +} // namespace oneapi::dal::python diff --git a/onedal/utils/tests/test_validation.py b/onedal/utils/tests/test_validation.py new file mode 100644 index 0000000000..1835cea3b6 --- /dev/null +++ b/onedal/utils/tests/test_validation.py @@ -0,0 +1,144 @@ +# ============================================================================== +# Copyright 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import time + +import numpy as np +import numpy.random as rand +import pytest +import scipy.sparse as sp + +from onedal.tests.utils._dataframes_support import ( + _convert_to_dataframe, + get_dataframes_and_queues, +) +from onedal.utils.validation import assert_all_finite + + +@pytest.mark.parametrize("dtype", [np.float32, np.float64]) +@pytest.mark.parametrize( + "shape", + [ + [16, 2048], + [65539], # 2**16 + 3, + [1000, 1000], + [ + 3, + ], + ], +) +@pytest.mark.parametrize("allow_nan", [False, True]) +@pytest.mark.parametrize( + "dataframe, queue", get_dataframes_and_queues("numpy,dpnp,dpctl") +) +def test_sum_infinite_actually_finite(dtype, shape, allow_nan, dataframe, queue): + X = np.empty(shape, dtype=dtype) + X.fill(np.finfo(dtype).max) + X = _convert_to_dataframe(X, sycl_queue=queue, target_df=dataframe) + assert_all_finite(X, allow_nan=allow_nan) + + +@pytest.mark.parametrize("dtype", [np.float32, np.float64]) +@pytest.mark.parametrize( + "shape", + [ + [16, 2048], + [65539], # 2**16 + 3, + [1000, 1000], + [ + 3, + ], + ], +) +@pytest.mark.parametrize("allow_nan", [False, True]) +@pytest.mark.parametrize("check", ["inf", "NaN", None]) +@pytest.mark.parametrize("seed", [0, int(time.time())]) +@pytest.mark.parametrize( + "dataframe, queue", get_dataframes_and_queues("numpy,dpnp,dpctl") +) +def test_assert_finite_random_location( + dtype, shape, allow_nan, check, seed, dataframe, queue +): + rand.seed(seed) + X = rand.uniform(high=np.finfo(dtype).max, size=shape).astype(dtype) + + if check: + loc = rand.randint(0, X.size - 1) + X.reshape((-1,))[loc] = float(check) + + X = _convert_to_dataframe(X, sycl_queue=queue, target_df=dataframe) + + if check is None or (allow_nan and check == "NaN"): + assert_all_finite(X, allow_nan=allow_nan) + else: + msg_err = "Input contains " + ("infinity" if allow_nan else "NaN, infinity") + "." + with pytest.raises(ValueError, match=msg_err): + assert_all_finite(X, allow_nan=allow_nan) + + +@pytest.mark.parametrize("dtype", [np.float32, np.float64]) +@pytest.mark.parametrize("allow_nan", [False, True]) +@pytest.mark.parametrize("check", ["inf", "NaN", None]) +@pytest.mark.parametrize("seed", [0, int(time.time())]) +@pytest.mark.parametrize( + "dataframe, queue", get_dataframes_and_queues("numpy,dpnp,dpctl") +) +def test_assert_finite_random_shape_and_location( + dtype, allow_nan, check, seed, dataframe, queue +): + lb, ub = 2, 1048576 # ub is 2^20 + rand.seed(seed) + X = rand.uniform(high=np.finfo(dtype).max, size=rand.randint(lb, ub)).astype(dtype) + + if check: + loc = rand.randint(0, X.size - 1) + X[loc] = float(check) + + X = _convert_to_dataframe(X, sycl_queue=queue, target_df=dataframe) + + if check is None or (allow_nan and check == "NaN"): + assert_all_finite(X, allow_nan=allow_nan) + else: + msg_err = "Input contains " + ("infinity" if allow_nan else "NaN, infinity") + "." + with pytest.raises(ValueError, match=msg_err): + assert_all_finite(X, allow_nan=allow_nan) + + +@pytest.mark.parametrize("dtype", [np.float32, np.float64]) +@pytest.mark.parametrize("allow_nan", [False, True]) +@pytest.mark.parametrize("check", ["inf", "NaN", None]) +@pytest.mark.parametrize("seed", [0, int(time.time())]) +def test_assert_finite_sparse(dtype, allow_nan, check, seed): + lb, ub = 2, 2056 + rand.seed(seed) + X = sp.random( + rand.randint(lb, ub), + rand.randint(lb, ub), + format="csr", + dtype=dtype, + random_state=rand.default_rng(seed), + ) + + if check: + locx = rand.randint(0, X.data.shape[0] - 1) + X.data[locx] = float(check) + + if check is None or (allow_nan and check == "NaN"): + assert_all_finite(X, allow_nan=allow_nan) + else: + msg_err = "Input contains " + ("infinity" if allow_nan else "NaN, infinity") + "." + with pytest.raises(ValueError, match=msg_err): + assert_all_finite(X, allow_nan=allow_nan) diff --git a/onedal/utils/validation.py b/onedal/utils/validation.py index 7559c43e4a..145e44b107 100644 --- a/onedal/utils/validation.py +++ b/onedal/utils/validation.py @@ -31,7 +31,12 @@ from sklearn.preprocessing import LabelEncoder from sklearn.utils.validation import check_array -from daal4py.sklearn.utils.validation import _assert_all_finite +from daal4py.sklearn.utils.validation import ( + _assert_all_finite as _daal4py_assert_all_finite, +) +from onedal import _backend +from onedal.common._policy import _get_policy +from onedal.datatypes import _convert_to_supported, to_table class DataConversionWarning(UserWarning): @@ -135,10 +140,10 @@ def _check_array( if force_all_finite: if sp.issparse(array): if hasattr(array, "data"): - _assert_all_finite(array.data) + _daal4py_assert_all_finite(array.data) force_all_finite = False else: - _assert_all_finite(array) + _daal4py_assert_all_finite(array) force_all_finite = False array = check_array( array=array, @@ -191,7 +196,7 @@ def _check_X_y( if y_numeric and y.dtype.kind == "O": y = y.astype(np.float64) if force_all_finite: - _assert_all_finite(y) + _daal4py_assert_all_finite(y) lengths = [X.shape[0], y.shape[0]] uniques = np.unique(lengths) @@ -276,7 +281,7 @@ def _type_of_target(y): # check float and contains non-integer float values if y.dtype.kind == "f" and np.any(y != y.astype(int)): # [.1, .2, 3] or [[.1, .2, 3]] or [[1., .2]] and not [1., 2., 3.] - _assert_all_finite(y) + _daal4py_assert_all_finite(y) return "continuous" + suffix if (len(np.unique(y)) > 2) or (y.ndim >= 2 and len(y[0]) > 1): @@ -429,3 +434,31 @@ def _is_csr(x): return isinstance(x, sp.csr_matrix) or ( hasattr(sp, "csr_array") and isinstance(x, sp.csr_array) ) + + +def _assert_all_finite(X, allow_nan=False, input_name=""): + policy = _get_policy(None, X) + X_t = to_table(_convert_to_supported(policy, X)) + params = { + "fptype": X_t.dtype, + "method": "dense", + "allow_nan": allow_nan, + } + if not _backend.finiteness_checker.compute.compute(policy, params, X_t).finite: + type_err = "infinity" if allow_nan else "NaN, infinity" + padded_input_name = input_name + " " if input_name else "" + msg_err = f"Input {padded_input_name}contains {type_err}." + raise ValueError(msg_err) + + +def assert_all_finite( + X, + *, + allow_nan=False, + input_name="", +): + _assert_all_finite( + X.data if sp.issparse(X) else X, + allow_nan=allow_nan, + input_name=input_name, + ) diff --git a/sklearnex/utils/tests/test_finite.py b/sklearnex/utils/tests/test_finite.py index 2874ec3400..7d83667699 100644 --- a/sklearnex/utils/tests/test_finite.py +++ b/sklearnex/utils/tests/test_finite.py @@ -37,7 +37,7 @@ ) @pytest.mark.parametrize("allow_nan", [False, True]) def test_sum_infinite_actually_finite(dtype, shape, allow_nan): - X = np.array(shape, dtype=dtype) + X = np.empty(shape, dtype=dtype) X.fill(np.finfo(dtype).max) _assert_all_finite(X, allow_nan=allow_nan) @@ -48,7 +48,7 @@ def test_sum_infinite_actually_finite(dtype, shape, allow_nan): [ [16, 2048], [ - 2**16 + 3, + 65539, # 2**16 + 3, ], [1000, 1000], ],