Skip to content

Fix DataArray.to_dataframe when the array has MultiIndex #4442

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Feb 20, 2021
Merged
3 changes: 3 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,9 @@ Bug fixes
a float64 array (:issue:`4898`, :pull:`4911`). By `Blair Bonnett <https://github.com/bcbnz>`_.
- Fix decoding of vlen strings using h5py versions greater than 3.0.0 with h5netcdf backend (:issue:`4570`, :pull:`4893`).
By `Kai Mühlbauer <https://github.com/kmuehlbauer>`_.
- Allow converting :py:class:`Dataset` or :py:class:`DataArray` objects with a ``MultiIndex``
and at least one other dimension to a ``pandas`` object (:issue:`3008`, :pull:`4442`).
By `ghislainp <https://github.com/ghislainp>`_.

Documentation
~~~~~~~~~~~~~
Expand Down
46 changes: 44 additions & 2 deletions xarray/core/coordinates.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
cast,
)

import numpy as np
import pandas as pd

from . import formatting, indexing
Expand Down Expand Up @@ -107,8 +108,49 @@ def to_index(self, ordered_dims: Sequence[Hashable] = None) -> pd.Index:
return self._data.get_index(dim) # type: ignore
else:
indexes = [self._data.get_index(k) for k in ordered_dims] # type: ignore
names = list(ordered_dims)
return pd.MultiIndex.from_product(indexes, names=names)

# compute the sizes of the repeat and tile for the cartesian product
# (taken from pandas.core.reshape.util)
index_lengths = np.fromiter(
(len(index) for index in indexes), dtype=np.intp
)
cumprod_lengths = np.cumproduct(index_lengths)

if cumprod_lengths[-1] != 0:
# sizes of the repeats
repeat_counts = cumprod_lengths[-1] / cumprod_lengths
else:
# if any factor is empty, the cartesian product is empty
repeat_counts = np.zeros_like(cumprod_lengths)

# sizes of the tiles
tile_counts = np.roll(cumprod_lengths, 1)
tile_counts[0] = 1

# loop over the indexes
# for each MultiIndex or Index compute the cartesian product of the codes

code_list = []
level_list = []
names = []

for i, index in enumerate(indexes):
if isinstance(index, pd.MultiIndex):
codes, levels = index.codes, index.levels
else:
code, level = pd.factorize(index)
codes = [code]
levels = [level]

# compute the cartesian product
code_list += [
np.tile(np.repeat(code, repeat_counts[i]), tile_counts[i])
for code in codes
]
level_list += levels
names += index.names

return pd.MultiIndex(level_list, code_list, names=names)

def update(self, other: Mapping[Hashable, Any]) -> None:
other_vars = getattr(other, "variables", other)
Expand Down
27 changes: 27 additions & 0 deletions xarray/tests/test_dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -3635,6 +3635,33 @@ def test_to_dataframe(self):
with raises_regex(ValueError, "unnamed"):
arr.to_dataframe()

def test_to_dataframe_multiindex(self):
# regression test for #3008
arr_np = np.random.randn(4, 3)

mindex = pd.MultiIndex.from_product([[1, 2], list("ab")], names=["A", "B"])

arr = DataArray(arr_np, [("MI", mindex), ("C", [5, 6, 7])], name="foo")

actual = arr.to_dataframe()
assert_array_equal(actual["foo"].values, arr_np.flatten())
assert_array_equal(actual.index.names, list("ABC"))
assert_array_equal(actual.index.levels[0], [1, 2])
assert_array_equal(actual.index.levels[1], ["a", "b"])
assert_array_equal(actual.index.levels[2], [5, 6, 7])

def test_to_dataframe_0length(self):
# regression test for #3008
arr_np = np.random.randn(4, 0)

mindex = pd.MultiIndex.from_product([[1, 2], list("ab")], names=["A", "B"])

arr = DataArray(arr_np, [("MI", mindex), ("C", [])], name="foo")

actual = arr.to_dataframe()
assert len(actual) == 0
assert_array_equal(actual.index.names, list("ABC"))

def test_to_pandas_name_matches_coordinate(self):
# coordinate with same name as array
arr = DataArray([1, 2, 3], dims="x", name="x")
Expand Down