Skip to content

Commit 08d5270

Browse files
committed
ENH: ndonnx device() support; TST: better ndonnx test coverage
1 parent 02fb925 commit 08d5270

File tree

5 files changed

+25
-6
lines changed

5 files changed

+25
-6
lines changed

array_api_compat/common/_helpers.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -645,7 +645,7 @@ def device(x: Array, /) -> Device:
645645
to_device : Move array data to a different device.
646646
647647
"""
648-
if is_numpy_array(x):
648+
if is_numpy_array(x) or is_ndonnx_array(x):
649649
return "cpu"
650650
elif is_dask_array(x):
651651
# Peek at the metadata of the jax array to determine type
@@ -772,7 +772,7 @@ def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]]
772772
device : Hardware device the array data resides on.
773773
774774
"""
775-
if is_numpy_array(x):
775+
if is_numpy_array(x) or is_ndonnx_array(x):
776776
if stream is not None:
777777
raise ValueError("The stream argument to to_device() is not supported")
778778
if device == 'cpu':

docs/supported-array-libraries.md

+5
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,11 @@ The minimum supported Dask version is 2023.12.0.
138138

139139
Similar to JAX, `sparse` Array API support is contained directly in `sparse`.
140140

141+
(ndonnx-support)=
142+
## [ndonnx](https://github.com/quantco/ndonnx)
143+
144+
Similar to JAX, `ndonnx` Array API support is contained directly in `ndonnx`.
145+
141146
(array-api-strict-support)=
142147
## [array-api-strict](https://data-apis.org/array-api-strict/)
143148

tests/_helpers.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,12 @@
33
import pytest
44

55
wrapped_libraries = ["numpy", "cupy", "torch", "dask.array"]
6-
all_libraries = wrapped_libraries + ["array_api_strict", "jax.numpy", "sparse"]
7-
6+
all_libraries = wrapped_libraries + [
7+
"array_api_strict", "jax.numpy", "ndonnx", "sparse"
8+
]
89

910
def import_(library, wrapper=False):
10-
if library == 'cupy':
11+
if library in ('cupy', 'ndonnx'):
1112
pytest.importorskip(library)
1213
if wrapper:
1314
if 'jax' in library:

tests/test_array_namespace.py

+3
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,9 @@ def test_array_namespace(library, api_version, use_compat):
2222
if use_compat and library not in wrapped_libraries:
2323
pytest.raises(ValueError, lambda: array_namespace(array, use_compat=use_compat))
2424
return
25+
if library == "ndonnx" and api_version in ("2021.12", "2022.12"):
26+
pytest.skip("Unsupported API version")
27+
2528
namespace = array_api_compat.array_namespace(array, api_version=api_version, use_compat=use_compat)
2629

2730
if use_compat is False or use_compat is None and library not in wrapped_libraries:

tests/test_common.py

+11-1
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,10 @@
88
from array_api_compat import ( # noqa: F401
99
is_numpy_array, is_cupy_array, is_torch_array,
1010
is_dask_array, is_jax_array, is_pydata_sparse_array,
11+
is_ndonnx_array,
1112
is_numpy_namespace, is_cupy_namespace, is_torch_namespace,
1213
is_dask_namespace, is_jax_namespace, is_pydata_sparse_namespace,
13-
is_array_api_strict_namespace,
14+
is_array_api_strict_namespace, is_ndonnx_namespace,
1415
)
1516

1617
from array_api_compat import (
@@ -25,6 +26,7 @@
2526
'dask.array': 'is_dask_array',
2627
'jax.numpy': 'is_jax_array',
2728
'sparse': 'is_pydata_sparse_array',
29+
'ndonnx': 'is_ndonnx_array',
2830
}
2931

3032
is_namespace_functions = {
@@ -35,6 +37,7 @@
3537
'jax.numpy': 'is_jax_namespace',
3638
'sparse': 'is_pydata_sparse_namespace',
3739
'array_api_strict': 'is_array_api_strict_namespace',
40+
'ndonnx': 'is_ndonnx_namespace',
3841
}
3942

4043

@@ -229,6 +232,13 @@ def _xfail(reason: str) -> None:
229232
# TODO: remove xfail once
230233
# https://github.com/dask/dask/issues/8260 is resolved
231234
_xfail(reason="Bug in dask raising error on conversion")
235+
elif (
236+
source_library == "ndonnx"
237+
and target_library not in ("array_api_strict", "ndonnx", "numpy")
238+
):
239+
_xfail(reason="The truth value of lazy Array Array(dtype=Boolean) is unknown")
240+
elif source_library == "ndonnx" and target_library == "numpy":
241+
_xfail(reason="produces numpy array of ndonnx scalar arrays")
232242
elif source_library == "jax.numpy" and target_library == "torch":
233243
_xfail(reason="casts int to float")
234244
elif source_library == "cupy" and target_library != "cupy":

0 commit comments

Comments
 (0)