Skip to content

Commit

Permalink
FIX: update functional support fallback logic a little bit
Browse files Browse the repository at this point in the history
host numpy copies of the inputs data will be used for the fallback cases, since stock scikit-learn doesn't support DPCTL usm_ndarray and DPNP ndarray
  • Loading branch information
samir-nasibli committed Oct 15, 2024
1 parent 89a37a8 commit 36cafcd
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 6 deletions.
6 changes: 3 additions & 3 deletions onedal/_device_offload.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def _transfer_to_host(queue, *data):
raise RuntimeError("Input data shall be located on single target device")

host_data.append(item)
return queue, host_data
return has_usm_data, queue, host_data


def _get_global_queue():
Expand All @@ -157,8 +157,8 @@ def _get_global_queue():

def _get_host_inputs(*args, **kwargs):
q = _get_global_queue()
q, hostargs = _transfer_to_host(q, *args)
q, hostvalues = _transfer_to_host(q, *kwargs.values())
_, q, hostargs = _transfer_to_host(q, *args)
_, q, hostvalues = _transfer_to_host(q, *kwargs.values())
hostkwargs = dict(zip(kwargs.keys(), hostvalues))
return q, hostargs, hostkwargs

Expand Down
7 changes: 4 additions & 3 deletions sklearnex/_device_offload.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,12 +63,12 @@ def _get_backend(obj, queue, method_name, *data):

def dispatch(obj, method_name, branches, *args, **kwargs):
q = _get_global_queue()
q, hostargs = _transfer_to_host(q, *args)
q, hostvalues = _transfer_to_host(q, *kwargs.values())
has_usm_data_for_args, q, hostargs = _transfer_to_host(q, *args)
has_usm_data_for_kwargs, q, hostvalues = _transfer_to_host(q, *kwargs.values())
hostkwargs = dict(zip(kwargs.keys(), hostvalues))

backend, q, patching_status = _get_backend(obj, q, method_name, *hostargs)

has_usm_data = has_usm_data_for_args or has_usm_data_for_kwargs
if backend == "onedal":
patching_status.write_log(queue=q)
return branches[backend](obj, *hostargs, **hostkwargs, queue=q)
Expand All @@ -78,6 +78,7 @@ def dispatch(obj, method_name, branches, *args, **kwargs):
and get_config()["array_api_dispatch"]
and "array_api_support" in obj._get_tags()
and obj._get_tags()["array_api_support"]
and not has_usm_data
):
# If `array_api_dispatch` enabled and array api is supported for the stock scikit-learn,
# then raw inputs are used for the fallback.
Expand Down

0 comments on commit 36cafcd

Please sign in to comment.