Skip to content

ENH: Support ExtensionArray in Groupby #20502

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Mar 28, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions pandas/core/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
DataError, SpecificationError)
from pandas.core.index import (Index, MultiIndex,
CategoricalIndex, _ensure_index)
from pandas.core.arrays import Categorical
from pandas.core.arrays import ExtensionArray, Categorical
from pandas.core.frame import DataFrame
from pandas.core.generic import NDFrame, _shared_docs
from pandas.core.internals import BlockManager, make_block
Expand Down Expand Up @@ -2968,7 +2968,7 @@ def __init__(self, index, grouper=None, obj=None, name=None, level=None,

# no level passed
elif not isinstance(self.grouper,
(Series, Index, Categorical, np.ndarray)):
(Series, Index, ExtensionArray, np.ndarray)):
if getattr(self.grouper, 'ndim', 1) != 1:
t = self.name or str(type(self.grouper))
raise ValueError("Grouper for '%s' not 1-dimensional" % t)
Expand Down
1 change: 1 addition & 0 deletions pandas/tests/extension/base/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ class TestMyDtype(BaseDtypeTests):
from .constructors import BaseConstructorsTests # noqa
from .dtype import BaseDtypeTests # noqa
from .getitem import BaseGetitemTests # noqa
from .groupby import BaseGroupbyTests # noqa
from .interface import BaseInterfaceTests # noqa
from .methods import BaseMethodsTests # noqa
from .missing import BaseMissingTests # noqa
Expand Down
69 changes: 69 additions & 0 deletions pandas/tests/extension/base/groupby.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import pytest

import pandas.util.testing as tm
import pandas as pd
from .base import BaseExtensionTests


class BaseGroupbyTests(BaseExtensionTests):
"""Groupby-specific tests."""

def test_grouping_grouper(self, data_for_grouping):
df = pd.DataFrame({
"A": ["B", "B", None, None, "A", "A", "B", "C"],
"B": data_for_grouping
})
gr1 = df.groupby("A").grouper.groupings[0]
gr2 = df.groupby("B").grouper.groupings[0]

tm.assert_numpy_array_equal(gr1.grouper, df.A.values)
tm.assert_extension_array_equal(gr2.grouper, data_for_grouping)

@pytest.mark.parametrize('as_index', [True, False])
def test_groupby_extension_agg(self, as_index, data_for_grouping):
df = pd.DataFrame({"A": [1, 1, 2, 2, 3, 3, 1, 4],
"B": data_for_grouping})
result = df.groupby("B", as_index=as_index).A.mean()
_, index = pd.factorize(data_for_grouping, sort=True)
# TODO(ExtensionIndex): remove astype
index = pd.Index(index.astype(object), name="B")
expected = pd.Series([3, 1, 4], index=index, name="A")
if as_index:
self.assert_series_equal(result, expected)
else:
expected = expected.reset_index()
self.assert_frame_equal(result, expected)

def test_groupby_extension_no_sort(self, data_for_grouping):
df = pd.DataFrame({"A": [1, 1, 2, 2, 3, 3, 1, 4],
"B": data_for_grouping})
result = df.groupby("B", sort=False).A.mean()
_, index = pd.factorize(data_for_grouping, sort=False)
# TODO(ExtensionIndex): remove astype
index = pd.Index(index.astype(object), name="B")
expected = pd.Series([1, 3, 4], index=index, name="A")
self.assert_series_equal(result, expected)

def test_groupby_extension_transform(self, data_for_grouping):
valid = data_for_grouping[~data_for_grouping.isna()]
df = pd.DataFrame({"A": [1, 1, 3, 3, 1, 4],
"B": valid})

result = df.groupby("B").A.transform(len)
expected = pd.Series([3, 3, 2, 2, 3, 1], name="A")

self.assert_series_equal(result, expected)

@pytest.mark.parametrize('op', [
lambda x: 1,
lambda x: [1] * len(x),
lambda x: pd.Series([1] * len(x)),
lambda x: x,
], ids=['scalar', 'list', 'series', 'object'])
def test_groupby_extension_apply(self, data_for_grouping, op):
df = pd.DataFrame({"A": [1, 1, 2, 2, 3, 3, 1, 4],
"B": data_for_grouping})
df.groupby("B").apply(op)
df.groupby("B").A.apply(op)
df.groupby("A").apply(op)
df.groupby("A").B.apply(op)
4 changes: 4 additions & 0 deletions pandas/tests/extension/decimal/test_decimal.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,10 @@ class TestCasting(BaseDecimal, base.BaseCastingTests):
pass


class TestGroupby(BaseDecimal, base.BaseGroupbyTests):
pass


def test_series_constructor_coerce_data_to_extension_dtype_raises():
xpr = ("Cannot cast data to extension dtype 'decimal'. Pass the "
"extension array directly.")
Expand Down
4 changes: 2 additions & 2 deletions pandas/tests/extension/json/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,8 +113,8 @@ def _concat_same_type(cls, to_concat):
return cls(data)

def _values_for_factorize(self):
frozen = tuple(tuple(x.items()) for x in self)
return np.array(frozen, dtype=object), ()
frozen = self._values_for_argsort()
return frozen, ()

def _values_for_argsort(self):
# Disable NumPy's shape inference by including an empty tuple...
Expand Down
41 changes: 37 additions & 4 deletions pandas/tests/extension/json/test_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,11 +89,12 @@ def test_fillna_frame(self):
"""We treat dictionaries as a mapping in fillna, not a scalar."""


class TestMethods(base.BaseMethodsTests):
unhashable = pytest.mark.skip(reason="Unhashable")
unstable = pytest.mark.skipif(not PY36, # 3.6 or higher
reason="Dictionary order unstable")
unhashable = pytest.mark.skip(reason="Unhashable")
unstable = pytest.mark.skipif(not PY36, # 3.6 or higher
reason="Dictionary order unstable")


class TestMethods(base.BaseMethodsTests):
@unhashable
def test_value_counts(self, all_data, dropna):
pass
Expand All @@ -118,6 +119,7 @@ def test_sort_values(self, data_for_sorting, ascending):
super(TestMethods, self).test_sort_values(
data_for_sorting, ascending)

@unstable
@pytest.mark.parametrize('ascending', [True, False])
def test_sort_values_missing(self, data_missing_for_sorting, ascending):
super(TestMethods, self).test_sort_values_missing(
Expand All @@ -126,3 +128,34 @@ def test_sort_values_missing(self, data_missing_for_sorting, ascending):

class TestCasting(base.BaseCastingTests):
pass


class TestGroupby(base.BaseGroupbyTests):

@unhashable
def test_groupby_extension_transform(self):
"""
This currently fails in Series.name.setter, since the
name must be hashable, but the value is a dictionary.
I think this is what we want, i.e. `.name` should be the original
values, and not the values for factorization.
"""

@unhashable
def test_groupby_extension_apply(self):
"""
This fails in Index._do_unique_check with

> hash(val)
E TypeError: unhashable type: 'UserDict' with

I suspect that once we support Index[ExtensionArray],
we'll be able to dispatch unique.
"""

@unstable
@pytest.mark.parametrize('as_index', [True, False])
def test_groupby_extension_agg(self, as_index, data_for_grouping):
super(TestGroupby, self).test_groupby_extension_agg(
as_index, data_for_grouping
)