Skip to content

Commit 29ec7ad

Browse files
authored
Add array_api_dispatch to spmd validation scope (#2940)
* Add array_api_dispatch to spmd validation scope
* formatting
* Combine onedal _device_offload.py change
* re-add transform output logic
* black
* minor restoration
* restore host transfer if no queue
* add new array api algos and import config_context
* debugging
* kmeans hotfix
* remove debug
1 parent: 3a63935 · commit: 29ec7ad

File tree

14 files changed

+156
-56
lines changed

14 files changed

+156
-56
lines changed

onedal/_device_offload.py

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -122,22 +122,25 @@ def wrapper_impl(*args, **kwargs):
122122
else:
123123
self = None
124124

125-
if len(args) == 0 and len(kwargs) == 0:
126-
# no arguments, there's nothing we can deduce from them -> just call the function
127-
return invoke_func(self, *args, **kwargs)
125+
if "queue" not in kwargs and "queue" in inspect.signature(func).parameters:
126+
if usm_iface := getattr(args[0], "__sycl_usm_array_interface__", None):
127+
kwargs["queue"] = usm_iface["syclobj"]
128128

129-
data = (*args, *kwargs.values())[0]
130-
# get and set the global queue from the kwarg or data
131-
with QM.manage_global_queue(kwargs.get("queue"), *args) as queue:
132-
hostargs, hostkwargs = _get_host_inputs(*args, **kwargs)
133-
if "queue" in inspect.signature(func).parameters:
134-
# set the queue if it's expected by func
135-
hostkwargs["queue"] = queue
136-
result = invoke_func(self, *hostargs, **hostkwargs)
129+
if kwargs.get("queue") is not None:
130+
# Device path — function accepts queue, pass device data directly
131+
result = invoke_func(self, *args, **kwargs)
132+
else:
133+
# Host path — sklearn function or host data, transfer to host
134+
if len(args) == 0 and len(kwargs) == 0:
135+
return invoke_func(self, *args, **kwargs)
137136

138-
if queue and hasattr(data, "__sycl_usm_array_interface__"):
139-
return copy_to_dpnp(queue, result)
137+
with QM.manage_global_queue(None, *args) as queue:
138+
hostargs, hostkwargs = _get_host_inputs(*args, **kwargs)
139+
result = invoke_func(self, *hostargs, **hostkwargs)
140+
if queue and hasattr(args[0], "__sycl_usm_array_interface__"):
141+
return copy_to_dpnp(queue, result)
140142

143+
data = (*args, *kwargs.values())[0]
141144
if get_config().get("transform_output") in ("default", None):
142145
input_array_api = getattr(data, "__array_namespace__", lambda: None)()
143146
if input_array_api and not _is_numpy_namespace(input_array_api):

sklearnex/spmd/basic_statistics/tests/test_basic_statistics_spmd.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
_convert_to_dataframe,
2525
get_dataframes_and_queues,
2626
)
27+
from sklearnex import config_context
2728
from sklearnex.tests.utils.spmd import (
2829
_generate_statistic_data,
2930
_get_local_tensor,
@@ -83,8 +84,11 @@ def test_basic_stats_spmd_gold(dataframe, queue):
8384
"dataframe,queue",
8485
get_dataframes_and_queues(dataframe_filter_="dpnp", device_filter_="gpu"),
8586
)
87+
@pytest.mark.parametrize("array_api_dispatch", [True, False])
8688
@pytest.mark.mpi
87-
def test_basic_stats_spmd_synthetic(n_samples, n_features, dataframe, queue, dtype):
89+
def test_basic_stats_spmd_synthetic(
90+
n_samples, n_features, dataframe, queue, dtype, array_api_dispatch
91+
):
8892
# Import spmd and batch algo
8993
from onedal.basic_statistics import BasicStatistics as BasicStatistics_Batch
9094
from sklearnex.spmd.basic_statistics import BasicStatistics as BasicStatistics_SPMD
@@ -97,7 +101,9 @@ def test_basic_stats_spmd_synthetic(n_samples, n_features, dataframe, queue, dty
97101
)
98102

99103
# Ensure results of batch algo match spmd
100-
spmd_result = BasicStatistics_SPMD().fit(local_dpt_data)
104+
# Configure array API dispatch status for spmd estimator
105+
with config_context(array_api_dispatch=array_api_dispatch):
106+
spmd_result = BasicStatistics_SPMD().fit(local_dpt_data)
101107
batch_result = BasicStatistics_Batch().fit(data)
102108

103109
tol = 1e-5 if dtype == np.float32 else 1e-7

sklearnex/spmd/basic_statistics/tests/test_incremental_basic_statistics_spmd.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
_convert_to_dataframe,
2525
get_dataframes_and_queues,
2626
)
27+
from sklearnex import config_context
2728
from sklearnex.tests.utils.spmd import (
2829
_generate_statistic_data,
2930
_get_local_tensor,
@@ -253,9 +254,17 @@ def test_incremental_basic_statistics_single_option_partial_fit_spmd_gold(
253254
@pytest.mark.parametrize("n_samples", [100, 10000])
254255
@pytest.mark.parametrize("n_features", [10, 100])
255256
@pytest.mark.parametrize("dtype", [np.float32, np.float64])
257+
@pytest.mark.parametrize("array_api_dispatch", [True, False])
256258
@pytest.mark.mpi
257259
def test_incremental_basic_statistics_partial_fit_spmd_synthetic(
258-
dataframe, queue, num_blocks, weighted, n_samples, n_features, dtype
260+
dataframe,
261+
queue,
262+
num_blocks,
263+
weighted,
264+
n_samples,
265+
n_features,
266+
dtype,
267+
array_api_dispatch,
259268
):
260269
# Import spmd and batch algo
261270
from sklearnex.basic_statistics import IncrementalBasicStatistics
@@ -295,9 +304,11 @@ def test_incremental_basic_statistics_partial_fit_spmd_synthetic(
295304
dpt_weights = _convert_to_dataframe(
296305
split_weights[i], sycl_queue=queue, target_df=dataframe
297306
)
298-
incbs_spmd.partial_fit(
299-
local_dpt_data, sample_weight=local_dpt_weights if weighted else None
300-
)
307+
# Configure array API dispatch for spmd estimator
308+
with config_context(array_api_dispatch=array_api_dispatch):
309+
incbs_spmd.partial_fit(
310+
local_dpt_data, sample_weight=local_dpt_weights if weighted else None
311+
)
301312
incbs.partial_fit(dpt_data, sample_weight=dpt_weights if weighted else None)
302313

303314
for option in options_and_tests:

sklearnex/spmd/cluster/tests/test_dbscan_spmd.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
_convert_to_dataframe,
2222
get_dataframes_and_queues,
2323
)
24+
from sklearnex import config_context
2425
from sklearnex.tests.utils.spmd import (
2526
_generate_clustering_data,
2627
_get_local_tensor,
@@ -69,6 +70,7 @@ def test_dbscan_spmd_gold(dataframe, queue):
6970
get_dataframes_and_queues(dataframe_filter_="dpnp", device_filter_="gpu"),
7071
)
7172
@pytest.mark.parametrize("dtype", [np.float32, np.float64])
73+
@pytest.mark.parametrize("array_api_dispatch", [True, False])
7274
@pytest.mark.mpi
7375
def test_dbscan_spmd_synthetic(
7476
n_samples,
@@ -78,6 +80,7 @@ def test_dbscan_spmd_synthetic(
7880
dataframe,
7981
queue,
8082
dtype,
83+
array_api_dispatch,
8184
):
8285
n_features, eps = n_features_and_eps
8386
# Import spmd and batch algo
@@ -93,7 +96,9 @@ def test_dbscan_spmd_synthetic(
9396
)
9497

9598
# Ensure labels from fit of batch algo matches spmd
96-
spmd_model = DBSCAN_SPMD(eps=eps, min_samples=min_samples).fit(local_dpt_data)
99+
# Configure array API dispatch for spmd estimator
100+
with config_context(array_api_dispatch=array_api_dispatch):
101+
spmd_model = DBSCAN_SPMD(eps=eps, min_samples=min_samples).fit(local_dpt_data)
97102
batch_model = DBSCAN_Batch(eps=eps, min_samples=min_samples).fit(data)
98103

99104
_spmd_assert_allclose(spmd_model.labels_, batch_model.labels_)

sklearnex/spmd/cluster/tests/test_kmeans_spmd.py

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,9 @@
2222
_convert_to_dataframe,
2323
get_dataframes_and_queues,
2424
)
25+
from sklearnex import config_context
2526
from sklearnex.tests.utils.spmd import (
27+
_as_numpy,
2628
_assert_kmeans_labels_allclose,
2729
_assert_unordered_allclose,
2830
_generate_clustering_data,
@@ -108,9 +110,10 @@ def test_kmeans_spmd_gold(dataframe, queue):
108110
get_dataframes_and_queues(dataframe_filter_="dpnp", device_filter_="gpu"),
109111
)
110112
@pytest.mark.parametrize("dtype", [np.float32, np.float64])
113+
@pytest.mark.parametrize("array_api_dispatch", [True, False])
111114
@pytest.mark.mpi
112115
def test_kmeans_spmd_synthetic(
113-
n_samples, n_features, n_clusters, dataframe, queue, dtype
116+
n_samples, n_features, n_clusters, dataframe, queue, dtype, array_api_dispatch
114117
):
115118
# Import spmd and batch algo
116119
from sklearnex.cluster import KMeans as KMeans_Batch
@@ -129,9 +132,11 @@ def test_kmeans_spmd_synthetic(
129132
)
130133

131134
# Validate KMeans init
132-
spmd_model_init = KMeans_SPMD(n_clusters=n_clusters, max_iter=1, random_state=0).fit(
133-
local_dpt_X_train
134-
)
135+
# Configure array_api_dispatch for spmd estimator
136+
with config_context(array_api_dispatch=array_api_dispatch):
137+
spmd_model_init = KMeans_SPMD(
138+
n_clusters=n_clusters, max_iter=1, random_state=0
139+
).fit(local_dpt_X_train)
135140
batch_model_init = KMeans_Batch(
136141
n_clusters=n_clusters, max_iter=1, random_state=0
137142
).fit(X_train)
@@ -142,9 +147,13 @@ def test_kmeans_spmd_synthetic(
142147
spmd_model = KMeans_SPMD(
143148
n_clusters=n_clusters, init=spmd_model_init.cluster_centers_, random_state=0
144149
)
145-
spmd_model.fit(local_dpt_X_train)
150+
# Configure array_api_dispatch for spmd estimator
151+
with config_context(array_api_dispatch=array_api_dispatch):
152+
spmd_model.fit(local_dpt_X_train)
146153
batch_model = KMeans_Batch(
147-
n_clusters=n_clusters, init=spmd_model_init.cluster_centers_, random_state=0
154+
n_clusters=n_clusters,
155+
init=_as_numpy(spmd_model_init.cluster_centers_),
156+
random_state=0,
148157
).fit(X_train)
149158

150159
atol = 1e-5 if dtype == np.float32 else 1e-7
@@ -162,7 +171,9 @@ def test_kmeans_spmd_synthetic(
162171
# assert_allclose(spmd_model.n_iter_, batch_model.n_iter_, atol=1)
163172

164173
# Ensure predictions of batch algo match spmd
165-
spmd_result = spmd_model.predict(local_dpt_X_test)
174+
# Configure array_api_dispatch for spmd estimator
175+
with config_context(array_api_dispatch=array_api_dispatch):
176+
spmd_result = spmd_model.predict(local_dpt_X_test)
166177
batch_result = batch_model.predict(X_test)
167178

168179
_assert_kmeans_labels_allclose(

sklearnex/spmd/covariance/tests/test_covariance_spmd.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
_convert_to_dataframe,
2424
get_dataframes_and_queues,
2525
)
26+
from sklearnex import config_context
2627
from sklearnex.tests.utils.spmd import (
2728
_generate_statistic_data,
2829
_get_local_tensor,
@@ -85,9 +86,10 @@ def test_covariance_spmd_gold(dataframe, queue):
8586
get_dataframes_and_queues(dataframe_filter_="dpnp", device_filter_="gpu"),
8687
)
8788
@pytest.mark.parametrize("dtype", [np.float32, np.float64])
89+
@pytest.mark.parametrize("array_api_dispatch", [True, False])
8890
@pytest.mark.mpi
8991
def test_covariance_spmd_synthetic(
90-
n_samples, n_features, assume_centered, dataframe, queue, dtype
92+
n_samples, n_features, assume_centered, dataframe, queue, dtype, array_api_dispatch
9193
):
9294
# Import spmd and batch algo
9395
from sklearnex.preview.covariance import (
@@ -103,9 +105,10 @@ def test_covariance_spmd_synthetic(
103105
)
104106

105107
# Ensure results of batch algo match spmd
106-
spmd_result = EmpiricalCovariance_SPMD(assume_centered=assume_centered).fit(
107-
local_dpt_data
108-
)
108+
with config_context(array_api_dispatch=array_api_dispatch):
109+
spmd_result = EmpiricalCovariance_SPMD(assume_centered=assume_centered).fit(
110+
local_dpt_data
111+
)
109112
batch_result = EmpiricalCovariance_Batch(assume_centered=assume_centered).fit(data)
110113

111114
atol = 1e-5 if dtype == np.float32 else 1e-7

sklearnex/spmd/covariance/tests/test_incremental_covariance_spmd.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
_convert_to_dataframe,
2424
get_dataframes_and_queues,
2525
)
26+
from sklearnex import config_context
2627
from sklearnex.tests.utils.spmd import (
2728
_generate_statistic_data,
2829
_get_local_tensor,
@@ -150,6 +151,7 @@ def test_incremental_covariance_partial_fit_spmd_gold(
150151
"dataframe,queue",
151152
get_dataframes_and_queues(dataframe_filter_="dpnp", device_filter_="gpu"),
152153
)
154+
@pytest.mark.parametrize("array_api_dispatch", [True, False])
153155
@pytest.mark.mpi
154156
def test_incremental_covariance_partial_fit_spmd_synthetic(
155157
n_samples,
@@ -159,6 +161,7 @@ def test_incremental_covariance_partial_fit_spmd_synthetic(
159161
dataframe,
160162
queue,
161163
dtype,
164+
array_api_dispatch,
162165
):
163166
# Import spmd and batch algo
164167
from sklearnex.covariance import IncrementalEmpiricalCovariance
@@ -181,7 +184,9 @@ def test_incremental_covariance_partial_fit_spmd_synthetic(
181184
local_dpt_data = _convert_to_dataframe(
182185
split_local_data[i], sycl_queue=queue, target_df=dataframe
183186
)
184-
inccov_spmd.partial_fit(local_dpt_data)
187+
# Configure array API dispatch status for spmd estimator
188+
with config_context(array_api_dispatch=array_api_dispatch):
189+
inccov_spmd.partial_fit(local_dpt_data)
185190

186191
inccov.fit(dpt_data)
187192

sklearnex/spmd/decomposition/tests/test_incremental_pca_spmd.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
_convert_to_dataframe,
2424
get_dataframes_and_queues,
2525
)
26+
from sklearnex import config_context
2627
from sklearnex.tests.utils.spmd import (
2728
_generate_statistic_data,
2829
_get_local_tensor,
@@ -218,6 +219,7 @@ def test_incremental_pca_fit_spmd_random(
218219
@pytest.mark.parametrize("num_samples", [200, 400])
219220
@pytest.mark.parametrize("num_features", [10, 20])
220221
@pytest.mark.parametrize("dtype", [np.float32, np.float64])
222+
@pytest.mark.parametrize("array_api_dispatch", [True, False])
221223
@pytest.mark.mpi
222224
def test_incremental_pca_partial_fit_spmd_random(
223225
dataframe,
@@ -228,6 +230,7 @@ def test_incremental_pca_partial_fit_spmd_random(
228230
num_samples,
229231
num_features,
230232
dtype,
233+
array_api_dispatch,
231234
):
232235
# Import spmd and non-SPMD algo
233236
from sklearnex.preview.decomposition import IncrementalPCA
@@ -252,7 +255,9 @@ def test_incremental_pca_partial_fit_spmd_random(
252255
split_local_X[i], sycl_queue=queue, target_df=dataframe
253256
)
254257
dpt_X = _convert_to_dataframe(X_split[i], sycl_queue=queue, target_df=dataframe)
255-
incpca_spmd.partial_fit(local_dpt_X)
258+
# Configure array API dispatch status for spmd estimator
259+
with config_context(array_api_dispatch=array_api_dispatch):
260+
incpca_spmd.partial_fit(local_dpt_X)
256261
incpca.partial_fit(dpt_X)
257262

258263
for attribute in attributes_to_compare:
@@ -263,7 +268,9 @@ def test_incremental_pca_partial_fit_spmd_random(
263268
err_msg=f"{attribute} is incorrect",
264269
)
265270

266-
y_trans_spmd = incpca_spmd.transform(dpt_X_test)
271+
# Configure array API dispatch status for spmd estimator
272+
with config_context(array_api_dispatch=array_api_dispatch):
273+
y_trans_spmd = incpca_spmd.transform(dpt_X_test)
267274
y_trans = incpca.transform(dpt_X_test)
268275

269276
assert_allclose(_as_numpy(y_trans_spmd), _as_numpy(y_trans), atol=tol)

sklearnex/spmd/decomposition/tests/test_pca_spmd.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
_convert_to_dataframe,
2424
get_dataframes_and_queues,
2525
)
26+
from sklearnex import config_context
2627
from sklearnex.tests.utils.spmd import (
2728
_generate_statistic_data,
2829
_get_local_tensor,
@@ -92,9 +93,17 @@ def test_pca_spmd_gold(dataframe, queue):
9293
get_dataframes_and_queues(dataframe_filter_="dpnp", device_filter_="gpu"),
9394
)
9495
@pytest.mark.parametrize("dtype", [np.float32, np.float64])
96+
@pytest.mark.parametrize("array_api_dispatch", [True, False])
9597
@pytest.mark.mpi
9698
def test_pca_spmd_synthetic(
97-
n_samples, n_features, n_components, whiten, dataframe, queue, dtype
99+
n_samples,
100+
n_features,
101+
n_components,
102+
whiten,
103+
dataframe,
104+
queue,
105+
dtype,
106+
array_api_dispatch,
98107
):
99108
# TODO: Resolve issues with batch fallback and lack of support for n_rows_rank < n_cols
100109
if n_components == "mle" or n_components == 3:
@@ -114,7 +123,10 @@ def test_pca_spmd_synthetic(
114123
)
115124

116125
# Ensure results of batch algo match spmd
117-
spmd_result = PCA_SPMD(n_components=n_components, whiten=whiten).fit(local_dpt_data)
126+
with config_context(array_api_dispatch=array_api_dispatch):
127+
spmd_result = PCA_SPMD(n_components=n_components, whiten=whiten).fit(
128+
local_dpt_data
129+
)
118130
batch_result = PCA_Batch(n_components=n_components, whiten=whiten).fit(data)
119131

120132
tol = 1e-3 if dtype == np.float32 else 1e-7

0 commit comments

Comments
 (0)