diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 26894e2a1d2..34654a85430 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -56,6 +56,9 @@ Enhancements - :py:meth:`DataArray.resample` and :py:meth:`Dataset.resample` now supports the ``loffset`` kwarg just like Pandas. By `Deepak Cherian `_ +- 0d slices of ndarrays are now obtained directly through indexing, rather than + extracting and wrapping a scalar, avoiding unnecessary copying. By `Daniel + Wennberg `_. Bug fixes ~~~~~~~~~ diff --git a/xarray/core/indexing.py b/xarray/core/indexing.py index d51da471c8d..02f2644d57b 100644 --- a/xarray/core/indexing.py +++ b/xarray/core/indexing.py @@ -1142,15 +1142,6 @@ def __init__(self, array): 'Trying to wrap {}'.format(type(array))) self.array = array - def _ensure_ndarray(self, value): - # We always want the result of indexing to be a NumPy array. If it's - # not, then it really should be a 0d array. Doing the coercion here - # instead of inside variable.as_compatible_data makes it less error - # prone. - if not isinstance(value, np.ndarray): - value = utils.to_0d_array(value) - return value - def _indexing_array_and_key(self, key): if isinstance(key, OuterIndexer): array = self.array @@ -1160,7 +1151,10 @@ def _indexing_array_and_key(self, key): key = key.tuple elif isinstance(key, BasicIndexer): array = self.array - key = key.tuple + # We want 0d slices rather than scalars. This is achieved by + # appending an ellipsis (see + # https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html#detailed-notes). # noqa + key = key.tuple + (Ellipsis,) else: raise TypeError('unexpected key type: {}'.format(type(key))) @@ -1171,7 +1165,7 @@ def transpose(self, order): def __getitem__(self, key): array, key = self._indexing_array_and_key(key) - return self._ensure_ndarray(array[key]) + return array[key] def __setitem__(self, key, value): array, key = self._indexing_array_and_key(key) diff --git a/xarray/tests/test_variable.py b/xarray/tests/test_variable.py index 84813f6c918..d98783fe2dd 100644 --- a/xarray/tests/test_variable.py +++ b/xarray/tests/test_variable.py @@ -1147,6 +1147,11 @@ def test_getitem_basic(self): assert v_new.dims == ('x', ) assert_array_equal(v_new, v._data[:, 1]) + # test that we obtain a modifiable view when taking a 0d slice + v_new = v[0, 0] + v_new[...] += 99 + assert_array_equal(v_new, v._data[0, 0]) + def test_getitem_with_mask_2d_input(self): v = Variable(('x', 'y'), [[0, 1, 2], [3, 4, 5]]) assert_identical(v._getitem_with_mask(([-1, 0], [1, -1])),