Skip to content

Commit 2ff7b4c

Browse files
authored
Support indexing with 0d-np.ndarray (#1922)
1 parent e0621c7 commit 2ff7b4c

File tree

3 files changed

+13
-1
lines changed

3 files changed

+13
-1
lines changed

doc/whats-new.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,8 @@ Enhancements
106106

107107
Bug fixes
108108
~~~~~~~~~
109+
- Support indexing with a 0d-np.ndarray (:issue:`1921`).
110+
By `Keisuke Fujii <https://github.com/fujiisoup>`_.
109111
- Added warning in api.py of a netCDF4 bug that occurs when
110112
the filepath has 88 characters (:issue:`1745`).
111113
By `Liam Brannigan <https://github.com/braaannigan>`_.

xarray/core/variable.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -463,10 +463,14 @@ def _broadcast_indexes(self, key):
463463
key = self._item_key_to_tuple(key) # key is a tuple
464464
# key is a tuple of full size
465465
key = indexing.expanded_indexer(key, self.ndim)
466-
# Convert a scalar Variable as an integer
466+
# Convert a scalar Variable to an integer
467467
key = tuple(
468468
k.data.item() if isinstance(k, Variable) and k.ndim == 0 else k
469469
for k in key)
470+
# Convert a 0d-array to an integer
471+
key = tuple(
472+
k.item() if isinstance(k, np.ndarray) and k.ndim == 0 else k
473+
for k in key)
470474

471475
if all(isinstance(k, BASIC_INDEXING_TYPES) for k in key):
472476
return self._broadcast_indexes_basic(key)

xarray/tests/test_variable.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -627,6 +627,12 @@ def test_getitem_0d_array(self):
627627
v_new = v[np.array([0])[0]]
628628
assert_array_equal(v_new, v_data[0])
629629

630+
v_new = v[np.array(0)]
631+
assert_array_equal(v_new, v_data[0])
632+
633+
v_new = v[Variable((), np.array(0))]
634+
assert_array_equal(v_new, v_data[0])
635+
630636
def test_getitem_fancy(self):
631637
v = self.cls(['x', 'y'], [[0, 1, 2], [3, 4, 5]])
632638
v_data = v.compute().data

0 commit comments

Comments
 (0)