Skip to content

Commit b5059a5

Browse files
fujiisoupshoyer
authored andcommitted
Fix multiindex selection (#2621)
* Fix multiindex selection * Support pandas0.19 * a bugfix * Do remove_unused_levels only once in unstack. * import algos * Remove unused import * Adopt local import
1 parent c2ce5ea commit b5059a5

File tree

5 files changed

+146
-3
lines changed

5 files changed

+146
-3
lines changed

doc/whats-new.rst

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,8 @@ Bug fixes
8383
By `Martin Raspaud <https://github.com/mraspaud>`_.
8484
- Fix parsing of ``_Unsigned`` attribute set by OPENDAP servers. (:issue:`2583`).
8585
By `Deepak Cherian <https://github.com/dcherian>`_
86-
86+
- Fix MultiIndex selection to update label and level (:issue:`2619`).
87+
By `Keisuke Fujii <https://github.com/fujiisoup>`_.
8788

8889
.. _whats-new.0.11.0:
8990

xarray/core/dataset.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
from . import (
1616
alignment, computation, duck_array_ops, formatting, groupby, indexing, ops,
17-
resample, rolling, utils)
17+
pdcompat, resample, rolling, utils)
1818
from .. import conventions
1919
from ..coding.cftimeindex import _parse_array_of_cftime_strings
2020
from .alignment import align
@@ -2425,6 +2425,12 @@ def stack(self, dimensions=None, **dimensions_kwargs):
24252425

24262426
def _unstack_once(self, dim):
24272427
index = self.get_index(dim)
2428+
# GH2619. Call remove_unused_levels first: full_idx below is built from index.levels, so levels no longer referenced after a selection must be dropped.
2429+
if LooseVersion(pd.__version__) >= "0.20":
2430+
index = index.remove_unused_levels()
2431+
else: # for pandas 0.19
2432+
index = pdcompat.remove_unused_levels(index)
2433+
24282434
full_idx = pd.MultiIndex.from_product(index.levels, names=index.names)
24292435

24302436
# take a shortcut in case the MultiIndex was not modified.

xarray/core/indexing.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,10 @@ def convert_label_indexer(index, label, index_name='', method=None,
159159
indexer, new_index = index.get_loc_level(
160160
tuple(label.values()), level=tuple(label.keys()))
161161

162+
# GH2619. Raise a KeyError if nothing is chosen
163+
if indexer.dtype.kind == 'b' and indexer.sum() == 0:
164+
raise KeyError('{} not found'.format(label))
165+
162166
elif isinstance(label, tuple) and isinstance(index, pd.MultiIndex):
163167
if _is_nested_tuple(label):
164168
indexer = index.get_locs(label)
@@ -168,7 +172,6 @@ def convert_label_indexer(index, label, index_name='', method=None,
168172
indexer, new_index = index.get_loc_level(
169173
label, level=list(range(len(label)))
170174
)
171-
172175
else:
173176
label = (label if getattr(label, 'ndim', 1) > 1 # vectorized-indexing
174177
else _asarray_tuplesafe(label))

xarray/core/pdcompat.py

Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
# The remove_unused_levels defined here was copied based on the source code
2+
# defined in pandas.core.indexes.multi.py
3+
4+
# For reference, here is a copy of the pandas copyright notice:
5+
6+
# (c) 2011-2012, Lambda Foundry, Inc. and PyData Development Team
7+
# All rights reserved.
8+
9+
# Copyright (c) 2008-2011 AQR Capital Management, LLC
10+
# All rights reserved.
11+
12+
# Redistribution and use in source and binary forms, with or without
13+
# modification, are permitted provided that the following conditions are
14+
# met:
15+
16+
# * Redistributions of source code must retain the above copyright
17+
# notice, this list of conditions and the following disclaimer.
18+
19+
# * Redistributions in binary form must reproduce the above
20+
# copyright notice, this list of conditions and the following
21+
# disclaimer in the documentation and/or other materials provided
22+
# with the distribution.
23+
24+
# * Neither the name of the copyright holder nor the names of any
25+
# contributors may be used to endorse or promote products derived
26+
# from this software without specific prior written permission.
27+
28+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDER AND CONTRIBUTORS
29+
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
30+
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
31+
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
32+
# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
33+
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
34+
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
35+
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
36+
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
37+
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
38+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
39+
40+
41+
import numpy as np
42+
import pandas as pd
43+
44+
45+
# for pandas 0.19
46+
def remove_unused_levels(self):
    """
    Build a MultiIndex equal to this one but without unused level values.

    A level value is "unused" when no entry of the corresponding label
    array points at it.  The returned index keeps the same values in the
    same order; only its ``levels`` / ``labels`` are re-encoded.

    This is a backport for pandas 0.19 of the ``remove_unused_levels``
    method that pandas itself provides from 0.20.0 onwards.

    Returns
    -------
    MultiIndex

    Examples
    --------
    >>> i = pd.MultiIndex.from_product([range(2), list('ab')])
    >>> i
    MultiIndex(levels=[[0, 1], ['a', 'b']],
               labels=[[0, 0, 1, 1], [0, 1, 0, 1]])
    >>> i[2:]
    MultiIndex(levels=[[0, 1], ['a', 'b']],
               labels=[[1, 1], [0, 1]])

    The 0 from the first level is not represented and can be removed:

    >>> i[2:].remove_unused_levels()
    MultiIndex(levels=[[1], ['a', 'b']],
               labels=[[0, 0], [0, 1]])
    """
    import pandas.core.algorithms as algos

    pruned_levels = []
    remapped_labels = []
    any_unused = False

    for level, codes in zip(self.levels, self.labels):
        # Cheap first pass: bincount() is faster than unique() when few
        # level values are unused, but it needs non-negative input (hence
        # the +1 shift for the -1 NA code) and it loses first-seen order.
        present = np.where(np.bincount(codes + 1) > 0)[0] - 1
        has_na = int(len(present) and (present[0] == -1))

        if len(present) != len(level) + has_na:
            # At least one level value is never referenced.
            any_unused = True

            # Second pass: the codes actually used, in order of appearance.
            present = algos.unique(codes)
            if has_na:
                # Swap the NA code (-1) into the first slot.
                na_pos = np.where(present == -1)[0]
                present[[0, na_pos[0]]] = present[[na_pos[0], 0]]

            # Lookup table old code -> new code.  It has one extra trailing
            # slot when NA is present, so that indexing with -1 hits that
            # slot, which is assigned the value -1 again.
            code_map = np.zeros(len(level) + has_na)
            code_map[present] = np.arange(len(present)) - has_na

            codes = code_map[codes]

            # Keep only the referenced level values (skipping the NA code).
            level = level.take(present[has_na:])

        pruned_levels.append(level)
        remapped_labels.append(codes)

    result = self._shallow_copy()

    if any_unused:
        result._reset_identity()
        result._set_levels(pruned_levels, validate=False)
        result._set_labels(remapped_labels, validate=False)

    return result

xarray/tests/test_dataarray.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1027,6 +1027,20 @@ def test_sel(lab_indexer, pos_indexer, replaced_idx=False,
10271027
assert_identical(mdata.sel(x={'one': 'a', 'two': 1}),
10281028
mdata.sel(one='a', two=1))
10291029

1030+
def test_selection_multiindex(self):
    # GH2619. After a positional selection on a stacked (MultiIndex)
    # dimension, level values that were dropped must not be selectable
    # and must not reappear when unstacking.
    stacked = xr.DataArray(
        np.arange(40).reshape(8, 5), dims=['x', 'y'],
        coords={'x': np.arange(8), 'y': np.arange(5)}
    ).stack(xy=['x', 'y'])
    keep = stacked['x'] < 4
    subset = stacked.isel(xy=keep)

    # x == 5 was filtered out above, so label selection has to fail.
    with pytest.raises(KeyError):
        subset.sel(x=5)

    actual = subset.unstack()
    # Reference result: do the same selection on a plain (reset) index
    # and rebuild the MultiIndex afterwards.
    expected = stacked.reset_index('xy').isel(xy=keep)
    expected = expected.set_index(xy=['x', 'y']).unstack()
    assert_identical(expected, actual)
1043+
10301044
def test_virtual_default_coords(self):
10311045
array = DataArray(np.zeros((5,)), dims='x')
10321046
expected = DataArray(range(5), dims='x', name='x')

0 commit comments

Comments
 (0)