From 36cafcd5ad81168d19b528ccafb7c3783008b488 Mon Sep 17 00:00:00 2001 From: Samir Nasibli Date: Tue, 15 Oct 2024 05:12:32 -0700 Subject: [PATCH] FIX: update functional support fallback logic a little bit host numpy copies of the inputs data will be used for the fallback cases, since stock scikit-learn doesn't support DPCTL usm_ndarray and DPNP ndarray --- onedal/_device_offload.py | 6 +++--- sklearnex/_device_offload.py | 7 ++++--- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/onedal/_device_offload.py b/onedal/_device_offload.py index 1eea282143..43c3da0b9a 100644 --- a/onedal/_device_offload.py +++ b/onedal/_device_offload.py @@ -140,7 +140,7 @@ def _transfer_to_host(queue, *data): raise RuntimeError("Input data shall be located on single target device") host_data.append(item) - return queue, host_data + return has_usm_data, queue, host_data def _get_global_queue(): @@ -157,8 +157,8 @@ def _get_global_queue(): def _get_host_inputs(*args, **kwargs): q = _get_global_queue() - q, hostargs = _transfer_to_host(q, *args) - q, hostvalues = _transfer_to_host(q, *kwargs.values()) + _, q, hostargs = _transfer_to_host(q, *args) + _, q, hostvalues = _transfer_to_host(q, *kwargs.values()) hostkwargs = dict(zip(kwargs.keys(), hostvalues)) return q, hostargs, hostkwargs diff --git a/sklearnex/_device_offload.py b/sklearnex/_device_offload.py index 06f97aa679..fd65be9c27 100644 --- a/sklearnex/_device_offload.py +++ b/sklearnex/_device_offload.py @@ -63,12 +63,12 @@ def _get_backend(obj, queue, method_name, *data): def dispatch(obj, method_name, branches, *args, **kwargs): q = _get_global_queue() - q, hostargs = _transfer_to_host(q, *args) - q, hostvalues = _transfer_to_host(q, *kwargs.values()) + has_usm_data_for_args, q, hostargs = _transfer_to_host(q, *args) + has_usm_data_for_kwargs, q, hostvalues = _transfer_to_host(q, *kwargs.values()) hostkwargs = dict(zip(kwargs.keys(), hostvalues)) backend, q, patching_status = _get_backend(obj, q, method_name, *hostargs) - + has_usm_data = has_usm_data_for_args or has_usm_data_for_kwargs if backend == "onedal": patching_status.write_log(queue=q) return branches[backend](obj, *hostargs, **hostkwargs, queue=q) @@ -78,6 +78,7 @@ def dispatch(obj, method_name, branches, *args, **kwargs): and get_config()["array_api_dispatch"] and "array_api_support" in obj._get_tags() and obj._get_tags()["array_api_support"] + and not has_usm_data ): # If `array_api_dispatch` enabled and array api is supported for the stock scikit-learn, # then raw inputs are used for the fallback.