diff --git a/onedal/_device_offload.py b/onedal/_device_offload.py index 1eea282143..43c3da0b9a 100644 --- a/onedal/_device_offload.py +++ b/onedal/_device_offload.py @@ -140,7 +140,7 @@ def _transfer_to_host(queue, *data): raise RuntimeError("Input data shall be located on single target device") host_data.append(item) - return queue, host_data + return has_usm_data, queue, host_data def _get_global_queue(): @@ -157,8 +157,8 @@ def _get_global_queue(): def _get_host_inputs(*args, **kwargs): q = _get_global_queue() - q, hostargs = _transfer_to_host(q, *args) - q, hostvalues = _transfer_to_host(q, *kwargs.values()) + _, q, hostargs = _transfer_to_host(q, *args) + _, q, hostvalues = _transfer_to_host(q, *kwargs.values()) hostkwargs = dict(zip(kwargs.keys(), hostvalues)) return q, hostargs, hostkwargs diff --git a/sklearnex/_device_offload.py b/sklearnex/_device_offload.py index 06f97aa679..fd65be9c27 100644 --- a/sklearnex/_device_offload.py +++ b/sklearnex/_device_offload.py @@ -63,12 +63,12 @@ def _get_backend(obj, queue, method_name, *data): def dispatch(obj, method_name, branches, *args, **kwargs): q = _get_global_queue() - q, hostargs = _transfer_to_host(q, *args) - q, hostvalues = _transfer_to_host(q, *kwargs.values()) + has_usm_data_for_args, q, hostargs = _transfer_to_host(q, *args) + has_usm_data_for_kwargs, q, hostvalues = _transfer_to_host(q, *kwargs.values()) hostkwargs = dict(zip(kwargs.keys(), hostvalues)) backend, q, patching_status = _get_backend(obj, q, method_name, *hostargs) - + has_usm_data = has_usm_data_for_args or has_usm_data_for_kwargs if backend == "onedal": patching_status.write_log(queue=q) return branches[backend](obj, *hostargs, **hostkwargs, queue=q) @@ -78,6 +78,7 @@ def dispatch(obj, method_name, branches, *args, **kwargs): and get_config()["array_api_dispatch"] and "array_api_support" in obj._get_tags() and obj._get_tags()["array_api_support"] + and not has_usm_data ): # If `array_api_dispatch` enabled and array api is supported for the stock scikit-learn, # then raw inputs are used for the fallback.