Skip to content

Commit 2053f7f

Browse files
committed
da.asarray should not materialize the graph
1 parent 5ef0e18 commit 2053f7f

File tree

4 files changed

+80
-27
lines changed

4 files changed

+80
-27
lines changed

array_api_compat/dask/array/_aliases.py

+15-17
Original file line numberDiff line numberDiff line change
@@ -129,24 +129,23 @@ def asarray(
129129
See the corresponding documentation in the array library and/or the array API
130130
specification for more details.
131131
"""
132+
if isinstance(obj, da.Array):
133+
if dtype is not None and dtype != obj.dtype:
134+
if copy is False:
135+
raise ValueError("Unable to avoid copy")
136+
obj = obj.astype(dtype)
137+
return obj.copy() if copy else obj
138+
132139
if copy is False:
133-
# copy=False is not yet implemented in dask
134-
raise NotImplementedError("copy=False is not yet implemented")
135-
elif copy is True:
136-
if isinstance(obj, da.Array) and dtype is None:
137-
return obj.copy()
138-
# Go through numpy, since dask copy is no-op by default
139-
obj = np.array(obj, dtype=dtype, copy=True)
140-
return da.array(obj, dtype=dtype)
141-
else:
142-
if not isinstance(obj, da.Array) or dtype is not None and obj.dtype != dtype:
143-
# copy=True to be uniform across dask < 2024.12 and >= 2024.12
144-
# see https://github.com/dask/dask/pull/11524/
145-
obj = np.array(obj, dtype=dtype, copy=True)
146-
return da.from_array(obj)
147-
return obj
140+
raise NotImplementedError(
141+
"copy=False is not possible when converting a non-dask object to dask"
142+
)
143+
144+
# copy=None to be uniform across dask < 2024.12 and >= 2024.12
145+
# see https://github.com/dask/dask/pull/11524/
146+
obj = np.array(obj, dtype=dtype, copy=True)
147+
return da.from_array(obj)
148148

149-
return da.asarray(obj, dtype=dtype, **kwargs)
150149

151150
from dask.array import (
152151
# Element wise aliases
@@ -184,7 +183,6 @@ def _isscalar(a):
184183
max_shape = () if _isscalar(max) else max.shape
185184

186185
# TODO: This won't handle dask unknown shapes
187-
import numpy as np
188186
result_shape = np.broadcast_shapes(x.shape, min_shape, max_shape)
189187

190188
if min is not None:

tests/test_all.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -40,5 +40,7 @@ def test_all(library):
4040
all_names = module.__all__
4141

4242
if set(dir_names) != set(all_names):
43-
assert set(dir_names) - set(all_names) == set(), f"Some dir() names not included in __all__ for {mod_name}"
44-
assert set(all_names) - set(dir_names) == set(), f"Some __all__ names not in dir() for {mod_name}"
43+
extra_dir = set(dir_names) - set(all_names)
44+
extra_all = set(all_names) - set(dir_names)
45+
assert not extra_dir, f"Some dir() names not included in __all__ for {mod_name}: {extra_dir}"
46+
assert not extra_all, f"Some __all__ names not in dir() for {mod_name}: {extra_all}"

tests/test_common.py

+14-8
Original file line numberDiff line numberDiff line change
@@ -226,11 +226,17 @@ def test_asarray_copy(library):
226226
all = xp.all if library != 'dask.array' else lambda x: xp.all(x).compute()
227227

228228
if library == 'numpy' and xp.__version__[0] < '2' and not hasattr(xp, '_CopyMode') :
229-
supports_copy_false = False
230-
elif library in ['cupy', 'dask.array']:
231-
supports_copy_false = False
229+
supports_copy_false_other_ns = False
230+
supports_copy_false_same_ns = False
231+
elif library == 'cupy':
232+
supports_copy_false_other_ns = False
233+
supports_copy_false_same_ns = False
234+
elif library == 'dask.array':
235+
supports_copy_false_other_ns = False
236+
supports_copy_false_same_ns = True
232237
else:
233-
supports_copy_false = True
238+
supports_copy_false_other_ns = True
239+
supports_copy_false_same_ns = True
234240

235241
a = asarray([1])
236242
b = asarray(a, copy=True)
@@ -240,7 +246,7 @@ def test_asarray_copy(library):
240246
assert all(a[0] == 0)
241247

242248
a = asarray([1])
243-
if supports_copy_false:
249+
if supports_copy_false_same_ns:
244250
b = asarray(a, copy=False)
245251
assert is_lib_func(b)
246252
a[0] = 0
@@ -249,7 +255,7 @@ def test_asarray_copy(library):
249255
pytest.raises(NotImplementedError, lambda: asarray(a, copy=False))
250256

251257
a = asarray([1])
252-
if supports_copy_false:
258+
if supports_copy_false_same_ns:
253259
pytest.raises(ValueError, lambda: asarray(a, copy=False,
254260
dtype=xp.float64))
255261
else:
@@ -281,7 +287,7 @@ def test_asarray_copy(library):
281287
for obj in [True, 0, 0.0, 0j, [0], [[0]]]:
282288
asarray(obj, copy=True) # No error
283289
asarray(obj, copy=None) # No error
284-
if supports_copy_false:
290+
if supports_copy_false_other_ns:
285291
pytest.raises(ValueError, lambda: asarray(obj, copy=False))
286292
else:
287293
pytest.raises(NotImplementedError, lambda: asarray(obj, copy=False))
@@ -294,7 +300,7 @@ def test_asarray_copy(library):
294300
assert all(b[0] == 1.0)
295301

296302
a = array.array('f', [1.0])
297-
if supports_copy_false:
303+
if supports_copy_false_other_ns:
298304
b = asarray(a, copy=False)
299305
assert is_lib_func(b)
300306
a[0] = 0.0

tests/test_dask.py

+47
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
import dask
2+
import numpy as np
3+
import pytest
4+
import dask.array as da
5+
6+
from array_api_compat import array_namespace
7+
8+
9+
@pytest.fixture
10+
def no_compute():
11+
"""
12+
Cause the test to raise if at any point anything calls compute() or persist(),
13+
e.g. as it can be triggered implicitly by __bool__, __array__, etc.
14+
"""
15+
def get(dsk, *args, **kwargs):
16+
raise AssertionError("Called compute() or persist()")
17+
18+
with dask.config.set(scheduler=get):
19+
yield
20+
21+
22+
def test_no_compute(no_compute):
23+
"""Test the no_compute fixture"""
24+
a = da.asarray(True)
25+
with pytest.raises(AssertionError, match="Called compute"):
26+
bool(a)
27+
28+
29+
def test_asarray_no_compute(no_compute):
30+
a = da.arange(10)
31+
xp = array_namespace(a) # wrap in array_api_compat.dask.array
32+
33+
xp.asarray(a)
34+
xp.asarray(a, dtype=np.int16)
35+
xp.asarray(a, dtype=a.dtype)
36+
xp.asarray(a, copy=True)
37+
xp.asarray(a, copy=True, dtype=np.int16)
38+
xp.asarray(a, copy=True, dtype=a.dtype)
39+
40+
41+
def test_clip_no_compute(no_compute):
42+
a = da.arange(10)
43+
xp = array_namespace(a) # wrap in array_api_compat.dask.array
44+
45+
xp.clip(a)
46+
xp.clip(a, 1)
47+
xp.clip(a, 1, 8)

0 commit comments

Comments
 (0)