Skip to content

Commit 781a555

Browse files
authored
Merge pull request #108 from lucascolley/dlpack-device
BUG: `from_dlpack`: fix default device
2 parents 74d3f7d + f8a6a9e commit 781a555

File tree

2 files changed

+9
-0
lines changed

2 files changed

+9
-0
lines changed

array_api_strict/_creation_functions.py

+2
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,8 @@ def from_dlpack(
226226
# Going to wait for upstream numpy support
227227
if device is not _default:
228228
_check_device(device)
229+
else:
230+
device = None
229231
if copy not in [_default, None]:
230232
raise NotImplementedError("The copy argument to from_dlpack is not yet implemented")
231233

array_api_strict/tests/test_creation_functions.py

+7
Original file line numberDiff line numberDiff line change
@@ -236,3 +236,10 @@ def from_dlpack_2023_12(api_version):
236236
pytest.raises(exception, lambda: from_dlpack(capsule, copy=False))
237237
pytest.raises(exception, lambda: from_dlpack(capsule, copy=True))
238238
pytest.raises(exception, lambda: from_dlpack(capsule, copy=None))
239+
240+
241+
def test_from_dlpack_default_device():
242+
x = asarray([1, 2, 3])
243+
y = from_dlpack(x)
244+
z = from_dlpack(np.asarray([1, 2, 3]))
245+
assert x.device == y.device == z.device == CPU_DEVICE

0 commit comments

Comments
 (0)