diff --git a/doc/whats-new.rst b/doc/whats-new.rst
index 373cb8d13dc..d7a9554fd40 100644
--- a/doc/whats-new.rst
+++ b/doc/whats-new.rst
@@ -77,6 +77,8 @@ Bug fixes
   By `Mayeul d'Avezac `_.
 - Return correct count for scalar datetime64 arrays (:issue:`2770`)
   By `Dan Nowacki `_.
+- Fixed max, min exception when applied to a MultiIndex (:issue:`2923`)
+  By `Ian Castleden `_.
 - A deep copy deep-copies the coords (:issue:`1463`)
   By `Martin Pletcher `_.
 - Increased support for `missing_value` (:issue:`2871`)
diff --git a/xarray/core/nanops.py b/xarray/core/nanops.py
index babc1dd97e6..06ab08e12fb 100644
--- a/xarray/core/nanops.py
+++ b/xarray/core/nanops.py
@@ -1,6 +1,6 @@
 import numpy as np

-from . import dtypes, nputils
+from . import dtypes, nputils, utils
 from .duck_array_ops import (
     _dask_or_eager_func, count, fillna, isnull, where_method)
 from .pycompat import dask_array_type
@@ -64,8 +64,10 @@ def _nan_minmax_object(func, fill_value, value, axis=None, **kwargs):
     filled_value = fillna(value, fill_value)
     data = getattr(np, func)(filled_value, axis=axis, **kwargs)
     if not hasattr(data, 'dtype'):  # scalar case
-        data = dtypes.fill_value(value.dtype) if valid_count == 0 else data
-        return np.array(data, dtype=value.dtype)
+        data = fill_value if valid_count == 0 else data
+        # we've computed a single min, max value of type object.
+        # don't let np.array turn a tuple back into an array
+        return utils.to_0d_object_array(data)
     return where_method(data, valid_count != 0)


diff --git a/xarray/tests/test_indexing.py b/xarray/tests/test_indexing.py
index 59435fea88b..c044d2ed1f3 100644
--- a/xarray/tests/test_indexing.py
+++ b/xarray/tests/test_indexing.py
@@ -43,6 +43,22 @@ def test_asarray_tuplesafe(self):
         assert res[0] == (0,)
         assert res[1] == (1,)

+    def test_stacked_multiindex_min_max(self):
+        data = np.random.randn(3, 23, 4)
+        da = DataArray(
+            data, name="value",
+            dims=["replicate", "rsample", "exp"],
+            coords=dict(
+                replicate=[0, 1, 2],
+                exp=["a", "b", "c", "d"],
+                rsample=list(range(23))
+            ),
+        )
+        da2 = da.stack(sample=("replicate", "rsample"))
+        s = da2.sample
+        assert_array_equal(da2.loc['a', s.max()], data[2, 22, 0])
+        assert_array_equal(da2.loc['b', s.min()], data[0, 0, 1])
+
     def test_convert_label_indexer(self):
         # TODO: add tests that aren't just for edge cases
         index = pd.Index([1, 2, 3])
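Context for the nanops change (illustrative sketch, not part of the diff): when min/max is taken over a stacked coordinate, the reduced value is a MultiIndex label, i.e. a plain tuple. np.array would unpack that tuple into a 1-d array, so the scalar label could no longer be used with .loc. The sketch below assumes utils.to_0d_object_array simply wraps its argument in a 0-d object array, which keeps the tuple intact as a single element.

    import numpy as np

    label = (2, 22)  # e.g. the max of the stacked "sample" coordinate

    # np.array unpacks the tuple into a 1-d object array -- the scalar label is lost
    assert np.array(label, dtype=object).shape == (2,)

    # wrapping it in a 0-d object array keeps the tuple as one element,
    # which is what utils.to_0d_object_array is assumed to do
    wrapped = np.empty((), dtype=object)
    wrapped[()] = label
    assert wrapped.shape == ()
    assert wrapped[()] == (2, 22)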