Skip to content

Commit 7b627e7

Browse files
committed
Fix some issues and add unit test
1 parent 71d5ff6 commit 7b627e7

File tree

2 files changed

+18
-6
lines changed

2 files changed

+18
-6
lines changed

xarray/core/indexing.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
is_duck_dask_array,
2020
sparse_array_type,
2121
)
22-
from .utils import maybe_cast_to_coords_dtype
22+
from .utils import maybe_cast_to_coords_dtype, is_duck_array
2323

2424

2525
def expanded_indexer(key, ndim):
@@ -308,7 +308,7 @@ def __init__(self, key):
308308
for k in key:
309309
if isinstance(k, slice):
310310
k = as_integer_slice(k)
311-
elif isinstance(k, np.ndarray) or isinstance(k, da.Array):
311+
elif is_duck_array(k):
312312
if not np.issubdtype(k.dtype, np.integer):
313313
raise TypeError(
314314
f"invalid indexer array, does not have integer dtype: {k!r}"
@@ -321,10 +321,7 @@ def __init__(self, key):
321321
"invalid indexer key: ndarray arguments "
322322
f"have different numbers of dimensions: {ndims}"
323323
)
324-
if isinstance(k, da.Array):
325-
k = da.asarray(k, dtype=np.int64)
326-
else:
327-
k = np.asarray(k, dtype=np.int64)
324+
k = k.astype(np.int64)
328325
else:
329326
raise TypeError(
330327
f"unexpected indexer type for {type(self).__name__}: {k!r}"

xarray/tests/test_indexing.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99

1010
from . import IndexerMaker, ReturnItem, assert_array_equal
1111

12+
da = pytest.importorskip("dask.array")
13+
1214
B = IndexerMaker(indexing.BasicIndexer)
1315

1416

@@ -729,3 +731,16 @@ def test_indexing_1d_object_array() -> None:
729731
expected = DataArray(expected_data)
730732

731733
assert [actual.data.item()] == [expected.data.item()]
734+
735+
736+
def test_indexing_dask_array():
737+
da = DataArray(
738+
np.ones(10 * 3 * 3).reshape((10, 3, 3)),
739+
dims=('time', 'x', 'y'),
740+
).chunk(dict(time=-1, x=1, y=1))
741+
da[{"time" : 9}]= 42
742+
743+
idx = da.argmax('time')
744+
actual = da.isel(time=idx)
745+
746+
assert np.all(actual == 42)

0 commit comments

Comments
 (0)