Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 0 additions & 6 deletions onedal/_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,17 +36,11 @@
If True, allows computation to fall back to sklearn after the onedal
backend in case of a runtime error in onedal backend computations.
Global default: True.
use_raw_input:
If True, uses the raw input data in some SPMD onedal backend computations
without any checks on data consistency or validity.
Note: This option is not recommended for general use.
Global default: False.
"""
_default_global_config = {
"target_offload": "auto",
"allow_fallback_to_host": False,
"allow_sklearn_after_onedal": True,
"use_raw_input": False,
}

_threadlocal = threading.local()
Expand Down
25 changes: 1 addition & 24 deletions onedal/_device_offload.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,21 +15,17 @@
# ==============================================================================

import inspect
import logging
from functools import wraps
from operator import xor

import numpy as np
from sklearn import get_config

from ._config import _get_config
from .datatypes import copy_to_dpnp, dlpack_to_numpy
from .utils import _sycl_queue_manager as QM
from .utils._array_api import _asarray, _get_sycl_namespace, _is_numpy_namespace
from .utils._third_party import is_dpnp_ndarray

logger = logging.getLogger("sklearnex")


def supports_queue(func):
"""Decorator that updates the global queue before function evaluation.
Expand Down Expand Up @@ -126,26 +122,7 @@ def wrapper_impl(*args, **kwargs):
else:
self = None

# KNeighbors*.fit can not be used with raw inputs, ignore `use_raw_input=True`
override_raw_input = (
self
and self.__class__.__name__ in ("KNeighborsClassifier", "KNeighborsRegressor")
and func.__name__ == "fit"
and _get_config()["use_raw_input"] is True
)
if override_raw_input:
pretty_name = f"{self.__class__.__name__}.{func.__name__}"
logger.warning(
f"Using raw inputs is not supported for {pretty_name}. Ignoring `use_raw_input=True` setting."
)
if _get_config()["use_raw_input"] is True and not override_raw_input:
if "queue" not in kwargs:
if usm_iface := getattr(args[0], "__sycl_usm_array_interface__", None):
kwargs["queue"] = usm_iface["syclobj"]
else:
kwargs["queue"] = None
return invoke_func(self, *args, **kwargs)
elif len(args) == 0 and len(kwargs) == 0:
if len(args) == 0 and len(kwargs) == 0:
# no arguments, there's nothing we can deduce from them -> just call the function
return invoke_func(self, *args, **kwargs)

Expand Down
27 changes: 0 additions & 27 deletions sklearnex/_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
# ==============================================================================

import sys
import warnings
from contextlib import contextmanager

from sklearn import get_config as skl_get_config
Expand Down Expand Up @@ -54,26 +53,10 @@
{tab}
{tab} Global default: ``True``.
{tab}
{tab}use_raw_input : bool or None
{tab} If ``True``, uses the raw input data in some SPMD onedal backend computations
{tab} without any checks on data consistency or validity. Note that this can be
{tab} better achieved through usage of :ref:`array API classes <array_api>` without
{tab} ``target_offload``. Not recommended for general use.
{tab}
{tab} Global default: ``False``.
{tab}
{tab} .. deprecated:: 2026.0
{tab}
{tab}sklearn_configs : kwargs
{tab} Other settings accepted by scikit-learn. See :obj:`sklearn.set_config` for
{tab} details.
{tab}
{tab}Warnings
{tab}--------
{tab}Using ``use_raw_input=True`` is not recommended for general use as it
{tab}bypasses data consistency checks, which may lead to unexpected behavior. It is
{tab}recommended to use the newer :ref:`array API <array_api>` instead.
{tab}
{tab}Note
{tab}----
{tab}Usage of ``target_offload`` requires additional dependencies - see
Expand Down Expand Up @@ -102,7 +85,6 @@ def set_config(
target_offload=None,
allow_fallback_to_host=None,
allow_sklearn_after_onedal=None,
use_raw_input=None,
**sklearn_configs,
): # numpydoc ignore=PR01,PR07
"""Set global configuration.
Expand All @@ -125,15 +107,6 @@ def set_config(
local_config["allow_fallback_to_host"] = allow_fallback_to_host
if allow_sklearn_after_onedal is not None:
local_config["allow_sklearn_after_onedal"] = allow_sklearn_after_onedal
if use_raw_input is not None:
if use_raw_input:
warnings.warn(
"The 'use_raw_input' parameter is deprecated and will be removed in version 2026.0. "
"On-device input validation can now be achieved by setting 'array_api_dispatch' to True.",
FutureWarning,
stacklevel=2,
)
local_config["use_raw_input"] = use_raw_input


set_config.__doc__ = set_config.__doc__.replace(
Expand Down
18 changes: 3 additions & 15 deletions sklearnex/_device_offload.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@

Depending on support conditions, oneDAL will be called, otherwise it will
fall back to calling scikit-learn. Dispatching to oneDAL can be influenced
by the 'use_raw_input' or 'allow_fallback_to_host' config parameters.
by the 'allow_fallback_to_host' config parameter.

Parameters
----------
Expand Down Expand Up @@ -112,10 +112,6 @@
object types should match for the sklearn and onedal object methods.
"""

if get_config()["use_raw_input"]:
with QM.manage_global_queue(None, *args) as queue:
return branches["onedal"](obj, *args, **kwargs, queue=queue)

# Determine if array_api dispatching is enabled, and if estimator is capable
onedal_array_api = _array_api_offload() and get_tags(obj).onedal_array_api
sklearn_array_api = _array_api_offload() and get_tags(obj).array_api_support
Expand Down Expand Up @@ -163,80 +159,72 @@
return branches["sklearn"](obj, *hostargs, **hostkwargs)


def wrap_output_data(func: Callable) -> Callable:
"""Transform function output to match input format.

Converts and moves the output arrays of the decorated function
to match the input array type and device.

Parameters
----------
func : callable
Function or method which has array data as input.

Returns
-------
wrapper : callable
Wrapped function or method which will return matching format.
"""

@wraps(func)
def wrapper(self, *args, **kwargs) -> Any:
result = func(self, *args, **kwargs)
# In case ARRAY API is enabled the result is already converted to the required type
if _array_api_offload() and get_tags(self).onedal_array_api:
# When transform_output is polars/pandas, sklearn's _set_output
# wrapper calls pl.DataFrame(result) which can't handle GPU arrays.
# Transfer to host so sklearn can wrap into the requested format.
if func.__name__ in ("transform", "fit_transform") and (
get_config().get("transform_output")
not in (
"default",
None,
)
or getattr(self, "_sklearn_output_config", {}).get("transform", "default")
!= "default"
):
_, (result,) = _transfer_to_host(result)
return result
if not (len(args) == 0 and len(kwargs) == 0):
data = (*args, *kwargs.values())[0]
# When transform_output is polars/pandas, sklearn's _set_output
# wrapper calls pl.DataFrame(result) which can't handle GPU arrays.
# Transfer to host so sklearn can wrap into the requested format.
if func.__name__ in ("transform", "fit_transform") and (
get_config().get("transform_output")
not in (
"default",
None,
)
or getattr(self, "_sklearn_output_config", {}).get("transform", "default")
!= "default"
):
_, (result,) = _transfer_to_host(result)
return result
# Remove check for result __sycl_usm_array_interface__ on deprecation of use_raw_inputs
if (
usm_iface := getattr(data, "__sycl_usm_array_interface__", None)
) and not hasattr(result, "__sycl_usm_array_interface__"):
# Skip if result elements are already SYCL arrays
# (e.g. kneighbors tuple from from_table(like=X))
if isinstance(result, (tuple, list)) and all(
hasattr(r, "__sycl_usm_array_interface__") for r in result
):
return result

if usm_iface := getattr(data, "__sycl_usm_array_interface__", None):
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

CC @yuejiaointel for review.

queue = usm_iface["syclobj"]
return copy_to_dpnp(queue, result)

if get_config().get("transform_output") in ("default", None):
if hasattr(data, "dtype"):
xp, is_array_api = get_namespace(data)
if is_array_api and not _is_numpy_namespace(xp):
device = getattr(data, "device", None)
if isinstance(result, tuple):
result = tuple(xp.asarray(r, device=device) for r in result)
elif not isinstance(result, (int, float)):
result = xp.asarray(result, device=device)
return result

return wrapper
26 changes: 12 additions & 14 deletions sklearnex/basic_statistics/basic_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
from onedal.basic_statistics import BasicStatistics as onedal_BasicStatistics
from onedal.utils.validation import _is_csr

from .._config import get_config
from .._device_offload import dispatch
from .._utils import PatchingConditionsChain
from ..base import oneDALEstimator
Expand Down Expand Up @@ -157,20 +156,19 @@ def _onedal_gpu_supported(self, method_name, *data):
return patching_status

def _onedal_fit(self, X, sample_weight=None, queue=None):
if not get_config()["use_raw_input"]:
xp, _ = get_namespace(X, sample_weight)
X = validate_data(
self,
X,
dtype=[xp.float64, xp.float32],
ensure_2d=False,
accept_sparse="csr",
)
xp, _ = get_namespace(X, sample_weight)
X = validate_data(
self,
X,
dtype=[xp.float64, xp.float32],
ensure_2d=False,
accept_sparse="csr",
)

if sample_weight is not None:
sample_weight = _check_sample_weight(
sample_weight, X, dtype=[xp.float64, xp.float32]
)
if sample_weight is not None:
sample_weight = _check_sample_weight(
sample_weight, X, dtype=[xp.float64, xp.float32]
)

onedal_params = {
"result_options": self.result_options,
Expand Down
17 changes: 7 additions & 10 deletions sklearnex/basic_statistics/incremental_basic_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
IncrementalBasicStatistics as onedal_IncrementalBasicStatistics,
)

from .._config import get_config
from .._device_offload import dispatch
from .._utils import PatchingConditionsChain, _add_inc_serialization_note
from ..base import oneDALEstimator
Expand Down Expand Up @@ -174,8 +173,7 @@ def _onedal_finalize_fit(self, queue=None):
def _onedal_partial_fit(self, X, sample_weight=None, queue=None, check_input=True):
first_pass = not hasattr(self, "n_samples_seen_") or self.n_samples_seen_ == 0

# never check input when using raw input
if check_input and not get_config()["use_raw_input"]:
if check_input:
xp, _ = get_namespace(X)
X = validate_data(
self,
Expand Down Expand Up @@ -204,14 +202,13 @@ def _onedal_partial_fit(self, X, sample_weight=None, queue=None, check_input=Tru
self._need_to_finalize = True

def _onedal_fit(self, X, sample_weight=None, queue=None):
if not get_config()["use_raw_input"]:
xp, _ = get_namespace(X, sample_weight)
X = validate_data(self, X, dtype=[xp.float64, xp.float32])
xp, _ = get_namespace(X, sample_weight)
X = validate_data(self, X, dtype=[xp.float64, xp.float32])

if sample_weight is not None:
sample_weight = _check_sample_weight(
sample_weight, X, dtype=[xp.float64, xp.float32]
)
if sample_weight is not None:
sample_weight = _check_sample_weight(
sample_weight, X, dtype=[xp.float64, xp.float32]
)

_, n_features = X.shape
if self.batch_size is None:
Expand Down
12 changes: 4 additions & 8 deletions sklearnex/cluster/dbscan.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
from onedal.cluster import DBSCAN as onedal_DBSCAN
from onedal.utils._array_api import _is_numpy_namespace

from .._config import get_config
from .._device_offload import dispatch
from .._utils import PatchingConditionsChain
from ..base import oneDALEstimator
Expand Down Expand Up @@ -77,14 +76,11 @@ def __init__(

def _onedal_fit(self, X, y, sample_weight=None, queue=None):
xp, _ = get_namespace(X, y, sample_weight)
if not get_config()["use_raw_input"]:
X = validate_data(
self, X, accept_sparse="csr", dtype=[xp.float64, xp.float32]
X = validate_data(self, X, accept_sparse="csr", dtype=[xp.float64, xp.float32])
if sample_weight is not None:
sample_weight = _check_sample_weight(
sample_weight, X, dtype=[xp.float64, xp.float32]
)
if sample_weight is not None:
sample_weight = _check_sample_weight(
sample_weight, X, dtype=[xp.float64, xp.float32]
)

onedal_params = {
"eps": self.eps,
Expand Down
90 changes: 41 additions & 49 deletions sklearnex/cluster/k_means.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,29 +193,28 @@ def _onedal_fit(self, X, _, sample_weight, queue=None):

xp, _ = get_namespace(X)

if not get_config()["use_raw_input"]:
if _is_arraylike_not_scalar(self.init):
init = validate_data(
self,
self.init,
dtype=[xp.float64, xp.float32],
accept_sparse="csr",
copy=True,
order="C",
reset=False,
)
self._validate_center_shape(X, init)
self.init = init

X = validate_data(
if _is_arraylike_not_scalar(self.init):
init = validate_data(
self,
X,
accept_sparse="csr",
self.init,
dtype=[xp.float64, xp.float32],
accept_sparse="csr",
copy=True,
order="C",
copy=self.copy_x,
accept_large_sparse=False,
reset=False,
)
self._validate_center_shape(X, init)
self.init = init

X = validate_data(
self,
X,
accept_sparse="csr",
dtype=[xp.float64, xp.float32],
order="C",
copy=self.copy_x,
accept_large_sparse=False,
)

# Validate critical parameters to match sklearn's _check_params
# behavior, which we bypass in the oneDAL path. This is needed
Expand Down Expand Up @@ -386,16 +385,13 @@ def predict(
def _onedal_predict(self, X, sample_weight=None, queue=None):

xp, _ = get_namespace(X)

if not get_config()["use_raw_input"]:
X = validate_data(
self,
X,
accept_sparse="csr",
reset=False,
dtype=[xp.float64, xp.float32],
)

X = validate_data(
self,
X,
accept_sparse="csr",
reset=False,
dtype=[xp.float64, xp.float32],
)
return self._onedal_estimator.predict(X, queue=queue)

def _onedal_supported(self, method_name, *data):
Expand Down Expand Up @@ -456,17 +452,15 @@ def transform(self, X):
def _onedal_transform(self, X, queue=None):

xp, is_array_api = get_namespace(X)

if not get_config()["use_raw_input"]:
X = validate_data(
self,
X,
accept_sparse="csr",
reset=False,
dtype=[xp.float64, xp.float32],
order="C",
accept_large_sparse=False,
)
X = validate_data(
self,
X,
accept_sparse="csr",
reset=False,
dtype=[xp.float64, xp.float32],
order="C",
accept_large_sparse=False,
)

if is_array_api:
centers = xp.asarray(self.cluster_centers_)
Expand Down Expand Up @@ -500,15 +494,13 @@ def score(self, X, y=None, sample_weight=None):
def _onedal_score(self, X, y=None, sample_weight=None, queue=None):

xp, _ = get_namespace(X)

if not get_config()["use_raw_input"]:
X = validate_data(
self,
X,
accept_sparse="csr",
reset=False,
dtype=[xp.float64, xp.float32],
)
X = validate_data(
self,
X,
accept_sparse="csr",
reset=False,
dtype=[xp.float64, xp.float32],
)

if not sklearn_check_version("1.5") and sklearn_check_version("1.3"):
if isinstance(sample_weight, str) and sample_weight == "deprecated":
Expand Down
Loading
Loading