Skip to content

Commit 01d9951

Browse files
committed
finish
1 parent 9cf0157 commit 01d9951

File tree

2 files changed

+35
-11
lines changed

2 files changed

+35
-11
lines changed

xarray/core/indexing.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -473,12 +473,6 @@ def __init__(self, key: tuple[slice | np.ndarray[Any, np.dtype[np.generic]], ...
473473
for k in key:
474474
if isinstance(k, slice):
475475
k = as_integer_slice(k)
476-
# elif is_duck_dask_array(k):
477-
# raise ValueError(
478-
# "Vectorized indexing with Dask arrays is not supported. "
479-
# "Please pass a numpy array by calling ``.compute``. "
480-
# "See https://github.com/dask/dask/issues/8958."
481-
# )
482476
elif is_duck_array(k):
483477
if not np.issubdtype(k.dtype, np.integer):
484478
raise TypeError(
@@ -1509,6 +1503,7 @@ def _oindex_get(self, indexer: OuterIndexer):
15091503
return self.array[key]
15101504

15111505
def _vindex_get(self, indexer: VectorizedIndexer):
1506+
_assert_not_chunked_indexer(indexer.tuple)
15121507
array = NumpyVIndexAdapter(self.array)
15131508
return array[indexer.tuple]
15141509

@@ -1620,6 +1615,16 @@ def _apply_vectorized_indexer_dask_wrapper(indices, coord):
16201615
)
16211616

16221617

1618+
def _assert_not_chunked_indexer(idxr: tuple[Any, ...]) -> None:
1619+
if any(is_chunked_array(i) for i in idxr):
1620+
raise ValueError(
1621+
"Cannot index with a chunked array indexer. "
1622+
"Please chunk the array you are indexing first, "
1623+
"and drop any indexed dimension coordinate variables. "
1624+
"Alternatively, call `.compute()` on any chunked arrays in the indexer."
1625+
)
1626+
1627+
16231628
class DaskIndexingAdapter(ExplicitlyIndexedNDArrayMixin):
16241629
"""Wrap a dask array to support explicit indexing."""
16251630

@@ -1811,6 +1816,7 @@ def _vindex_get(
18111816
| np.datetime64
18121817
| np.timedelta64
18131818
):
1819+
_assert_not_chunked_indexer(indexer.tuple)
18141820
key = self._prepare_key(indexer.tuple)
18151821

18161822
if getattr(key, "ndim", 0) > 1: # Return np-array if multidimensional

xarray/tests/test_indexing.py

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -974,7 +974,7 @@ def test_indexing_1d_object_array() -> None:
974974

975975

976976
@requires_dask
977-
def test_indexing_dask_array():
977+
def test_indexing_dask_array() -> None:
978978
import dask.array
979979

980980
da = DataArray(
@@ -988,7 +988,7 @@ def test_indexing_dask_array():
988988

989989

990990
@requires_dask
991-
def test_indexing_dask_array_scalar():
991+
def test_indexing_dask_array_scalar() -> None:
992992
# GH4276
993993
import dask.array
994994

@@ -1002,19 +1002,37 @@ def test_indexing_dask_array_scalar():
10021002

10031003

10041004
@requires_dask
1005-
def test_vectorized_indexing_dask_array():
1005+
def test_vectorized_indexing_dask_array() -> None:
10061006
# https://github.com/pydata/xarray/issues/2511#issuecomment-563330352
10071007
darr = DataArray(data=[0.2, 0.4, 0.6], coords={"z": range(3)}, dims=("z",))
10081008
indexer = DataArray(
10091009
data=np.random.randint(0, 3, 8).reshape(4, 2).astype(int),
10101010
coords={"y": range(4), "x": range(2)},
10111011
dims=("y", "x"),
10121012
)
1013-
darr[indexer.chunk({"y": 2})]
1013+
expected = darr[indexer]
1014+
1015+
# fails because we can't index pd.Index lazily (yet)
1016+
with pytest.raises(ValueError, match="Cannot index with"):
1017+
with raise_if_dask_computes():
1018+
darr.chunk()[indexer.chunk({"y": 2})]
1019+
1020+
# fails because we can't index pd.Index lazily (yet)
1021+
with pytest.raises(ValueError, match="Cannot index with"):
1022+
with raise_if_dask_computes():
1023+
actual = darr[indexer.chunk({"y": 2})]
1024+
1025+
with raise_if_dask_computes():
1026+
actual = darr.drop_vars("z").chunk()[indexer.chunk({"y": 2})]
1027+
assert_identical(actual, expected.drop_vars("z"))
1028+
1029+
with raise_if_dask_computes():
1030+
actual = darr.variable.chunk()[indexer.variable.chunk({"y": 2})]
1031+
assert_identical(actual, expected.variable)
10141032

10151033

10161034
@requires_dask
1017-
def test_advanced_indexing_dask_array():
1035+
def test_advanced_indexing_dask_array() -> None:
10181036
# GH4663
10191037
import dask.array as da
10201038

0 commit comments

Comments
 (0)