Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 24 additions & 15 deletions databricks/koalas/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,12 @@
from collections.abc import Iterable
from distutils.version import LooseVersion
from functools import reduce
from typing import Optional, Union, List
from typing import List, Optional, Tuple, Union
import warnings

import numpy as np # noqa: F401
import pandas as pd
from pandas.api.types import is_list_like

from pyspark import sql as spark
from pyspark.sql import functions as F
Expand All @@ -38,6 +39,7 @@
from databricks.koalas.spark import functions as SF
from databricks.koalas.utils import (
is_name_like_tuple,
is_name_like_value,
name_like_string,
scol_for,
validate_arguments_and_invoke_function,
Expand Down Expand Up @@ -1521,27 +1523,34 @@ def groupby(self, by, axis=0, as_index: bool = True):

if isinstance(by, ks.DataFrame):
raise ValueError("Grouper for '{}' not 1-dimensional".format(type(by)))
elif isinstance(by, str):
elif isinstance(by, ks.Series):
by = [by]
elif is_name_like_tuple(by):
if isinstance(self, ks.Series):
raise KeyError(by)
by = [(by,)]
elif isinstance(by, tuple):
by = [by]
elif is_name_like_value(by):
if isinstance(self, ks.Series):
for key in by:
if isinstance(key, str):
raise KeyError(key)
raise KeyError(by)
by = [(by,)]
elif is_list_like(by):
new_by = [] # type: List[Union[Tuple, ks.Series]]
for key in by:
if isinstance(key, ks.DataFrame):
raise ValueError("Grouper for '{}' not 1-dimensional".format(type(key)))
by = [by]
elif isinstance(by, ks.Series):
by = [by]
elif isinstance(by, Iterable):
if isinstance(self, ks.Series):
for key in by:
if isinstance(key, str):
elif isinstance(key, ks.Series):
new_by.append(key)
elif is_name_like_tuple(key):
if isinstance(self, ks.Series):
raise KeyError(key)
by = [key if isinstance(key, (tuple, ks.Series)) else (key,) for key in by]
new_by.append(key)
elif is_name_like_value(key):
if isinstance(self, ks.Series):
raise KeyError(key)
new_by.append((key,))
else:
raise ValueError("Grouper for '{}' not 1-dimensional".format(type(key)))
by = new_by
else:
raise ValueError("Grouper for '{}' not 1-dimensional".format(type(by)))
if not len(by):
Expand Down
27 changes: 17 additions & 10 deletions databricks/koalas/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@
align_diff_frames,
column_labels_level,
is_name_like_tuple,
is_name_like_value,
name_like_string,
same_anchor,
scol_for,
Expand Down Expand Up @@ -208,7 +209,7 @@ def aggregate(self, func_or_funcs=None, *args, **kwargs):

if not isinstance(func_or_funcs, (str, list)):
if not isinstance(func_or_funcs, dict) or not all(
isinstance(key, (str, tuple))
is_name_like_value(key)
and (
isinstance(value, str)
or isinstance(value, list)
Expand All @@ -217,7 +218,7 @@ def aggregate(self, func_or_funcs=None, *args, **kwargs):
for key, value in func_or_funcs.items()
):
raise ValueError(
"aggs must be a dict mapping from column name (string or tuple) "
"aggs must be a dict mapping from column name "
"to aggregate functions (string or list of strings)."
)

Expand Down Expand Up @@ -2435,18 +2436,24 @@ def __getattr__(self, item: str) -> Any:
return self.__getitem__(item)

def __getitem__(self, item):
if isinstance(item, str) and self._as_index:
return SeriesGroupBy(self._kdf[item], self._groupkeys)
if self._as_index and is_name_like_value(item):
return SeriesGroupBy(
self._kdf._kser_for(item if is_name_like_tuple(item) else (item,)), self._groupkeys
)
else:
if isinstance(item, str):
if is_name_like_tuple(item):
item = [item]
item = [i if isinstance(i, tuple) else (i,) for i in item]
elif is_name_like_value(item):
item = [(item,)]
else:
item = [i if is_name_like_tuple(i) else (i,) for i in item]
if not self._as_index:
groupkey_names = set(key.name for key in self._groupkeys)
for i in item:
name = str(i) if len(i) > 1 else i[0]
groupkey_names = set(key._column_label for key in self._groupkeys)
for name in item:
if name in groupkey_names:
raise ValueError("cannot insert {}, already exists".format(name))
raise ValueError(
"cannot insert {}, already exists".format(name_like_string(name))
)
return DataFrameGroupBy(
self._kdf,
self._groupkeys,
Expand Down
136 changes: 93 additions & 43 deletions databricks/koalas/tests/test_groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,8 +119,8 @@ def test_groupby_simple(self):
self.assertRaises(ValueError, lambda: kdf.groupby("a", as_index=False)["a"])
self.assertRaises(ValueError, lambda: kdf.groupby("a", as_index=False)[["a"]])
self.assertRaises(ValueError, lambda: kdf.groupby("a", as_index=False)[["a", "c"]])
self.assertRaises(ValueError, lambda: kdf.groupby(0, as_index=False)[["a", "c"]])
self.assertRaises(KeyError, lambda: kdf.groupby([0], as_index=False)[["a", "c"]])
self.assertRaises(KeyError, lambda: kdf.groupby("z", as_index=False)[["a", "c"]])
self.assertRaises(KeyError, lambda: kdf.groupby(["z"], as_index=False)[["a", "c"]])

self.assertRaises(TypeError, lambda: kdf.a.groupby(kdf.b, as_index=False))

Expand All @@ -139,53 +139,84 @@ def test_groupby_simple(self):
self.assertRaises(ValueError, lambda: kdf.a.groupby(kdf))
self.assertRaises(ValueError, lambda: kdf.a.groupby((kdf,)))

# non-string names
pdf = pd.DataFrame(
{
10: [1, 2, 6, 4, 4, 6, 4, 3, 7],
20: [4, 2, 7, 3, 3, 1, 1, 1, 2],
30: [4, 2, 7, 3, None, 1, 1, 1, 2],
40: list("abcdefght"),
},
index=[0, 1, 3, 5, 6, 8, 9, 9, 9],
)
kdf = ks.from_pandas(pdf)

for as_index in [True, False]:
if as_index:
sort = lambda df: df.sort_index()
else:
sort = lambda df: df.sort_values(10).reset_index(drop=True)
self.assert_eq(
sort(kdf.groupby(10, as_index=as_index).sum()),
sort(pdf.groupby(10, as_index=as_index).sum()),
)
self.assert_eq(
sort(kdf.groupby(10, as_index=as_index)[20].sum()),
sort(pdf.groupby(10, as_index=as_index)[20].sum()),
)
self.assert_eq(
sort(kdf.groupby(10, as_index=as_index)[[20, 30]].sum()),
sort(pdf.groupby(10, as_index=as_index)[[20, 30]].sum()),
)

def test_groupby_multiindex_columns(self):
pdf = pd.DataFrame(
{
("x", "a"): [1, 2, 6, 4, 4, 6, 4, 3, 7],
("x", "b"): [4, 2, 7, 3, 3, 1, 1, 1, 2],
("y", "c"): [4, 2, 7, 3, None, 1, 1, 1, 2],
("z", "d"): list("abcdefght"),
(10, "a"): [1, 2, 6, 4, 4, 6, 4, 3, 7],
(10, "b"): [4, 2, 7, 3, 3, 1, 1, 1, 2],
(20, "c"): [4, 2, 7, 3, None, 1, 1, 1, 2],
(30, "d"): list("abcdefght"),
},
index=[0, 1, 3, 5, 6, 8, 9, 9, 9],
)
kdf = ks.from_pandas(pdf)

self.assert_eq(
kdf.groupby(("x", "a")).sum().sort_index(), pdf.groupby(("x", "a")).sum().sort_index()
kdf.groupby((10, "a")).sum().sort_index(), pdf.groupby((10, "a")).sum().sort_index()
)
self.assert_eq(
kdf.groupby(("x", "a"), as_index=False)
kdf.groupby((10, "a"), as_index=False)
.sum()
.sort_values(("x", "a"))
.sort_values((10, "a"))
.reset_index(drop=True),
pdf.groupby(("x", "a"), as_index=False)
pdf.groupby((10, "a"), as_index=False)
.sum()
.sort_values(("x", "a"))
.sort_values((10, "a"))
.reset_index(drop=True),
)
self.assert_eq(
kdf.groupby(("x", "a"))[[("y", "c")]].sum().sort_index(),
pdf.groupby(("x", "a"))[[("y", "c")]].sum().sort_index(),
kdf.groupby((10, "a"))[[(20, "c")]].sum().sort_index(),
pdf.groupby((10, "a"))[[(20, "c")]].sum().sort_index(),
)

# TODO: a pandas bug?
# expected = pdf.groupby((10, "a"))[(20, "c")].sum().sort_index()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit of nit: one more space here in the comment 😅

Copy link
Contributor

@itholic itholic Oct 7, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Anyway, pdf.groupby((10, "a"))[[(20, "c")]].sum() seems working.

>>> pdf
  10      20 30
   a  b    c  d
0  1  4  4.0  a
1  2  2  2.0  b
3  6  7  7.0  c
5  4  3  3.0  d
6  4  3  NaN  e
8  6  1  1.0  f
9  4  1  1.0  g
9  3  1  1.0  h
9  7  2  2.0  t

>>> pdf.groupby((10, "a"))[[(20, "c")]].sum()
          20
           c
(10, a)
1        4.0
2        2.0
3        1.0
4        4.0
6        8.0
7        2.0

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, Sorry. nvm above. It's different from pdf.groupby((10, "a"))[(20, "c")].sum() anyway.

expected = pd.Series(
[4.0, 2.0, 1.0, 4.0, 8.0, 2.0],
name=(20, "c"),
index=pd.Index([1, 2, 3, 4, 6, 7], name=(10, "a")),
)
# TODO: should work?
# self.assert_eq(
# kdf.groupby(("x", "a"))[("y", "c")].sum().sort_index(),
# pdf.groupby(("x", "a"))[("y", "c")].sum().sort_index(),
# )

self.assert_eq(kdf.groupby((10, "a"))[(20, "c")].sum().sort_index(), expected)

if LooseVersion(pd.__version__) < LooseVersion("1.1.3"):
self.assert_eq(
kdf[("x", "a")].groupby(kdf[("x", "b")]).sum().sort_index(),
pdf[("x", "a")].groupby(pdf[("x", "b")]).sum().sort_index(),
kdf[(20, "c")].groupby(kdf[(10, "a")]).sum().sort_index(),
pdf[(20, "c")].groupby(pdf[(10, "a")]).sum().sort_index(),
)
else:
# seems like a pandas bug introduced in pandas 1.1.3.
expected_result = ks.Series(
[13, 9, 8, 1, 6], name=("x", "a"), index=pd.Index([1, 2, 3, 4, 7], name=("x", "b"))
)
self.assert_eq(
kdf[("x", "a")].groupby(kdf[("x", "b")]).sum().sort_index(), expected_result
)
self.assert_eq(kdf[(20, "c")].groupby(kdf[(10, "a")]).sum().sort_index(), expected)

def test_split_apply_combine_on_series(self):
pdf = pd.DataFrame(
Expand Down Expand Up @@ -382,40 +413,63 @@ def test_aggregate(self):
)

expected_error_message = (
r"aggs must be a dict mapping from column name \(string or "
r"tuple\) to aggregate functions \(string or list of strings\)."
r"aggs must be a dict mapping from column name to aggregate functions "
r"\(string or list of strings\)."
)
with self.assertRaisesRegex(ValueError, expected_error_message):
kdf.groupby("A", as_index=as_index).agg(0)

# multi-index columns
columns = pd.MultiIndex.from_tuples([("X", "A"), ("X", "B"), ("Y", "C")])
columns = pd.MultiIndex.from_tuples([(10, "A"), (10, "B"), (20, "C")])
pdf.columns = columns
kdf.columns = columns

for as_index in [True, False]:
stats_kdf = kdf.groupby(("X", "A"), as_index=as_index).agg(
{("X", "B"): "min", ("Y", "C"): "sum"}
stats_kdf = kdf.groupby((10, "A"), as_index=as_index).agg(
{(10, "B"): "min", (20, "C"): "sum"}
)
stats_pdf = pdf.groupby(("X", "A"), as_index=as_index).agg(
{("X", "B"): "min", ("Y", "C"): "sum"}
stats_pdf = pdf.groupby((10, "A"), as_index=as_index).agg(
{(10, "B"): "min", (20, "C"): "sum"}
)
self.assert_eq(
stats_kdf.sort_values(by=[("X", "B"), ("Y", "C")]).reset_index(drop=True),
stats_pdf.sort_values(by=[("X", "B"), ("Y", "C")]).reset_index(drop=True),
stats_kdf.sort_values(by=[(10, "B"), (20, "C")]).reset_index(drop=True),
stats_pdf.sort_values(by=[(10, "B"), (20, "C")]).reset_index(drop=True),
)

stats_kdf = kdf.groupby(("X", "A")).agg({("X", "B"): ["min", "max"], ("Y", "C"): "sum"})
stats_pdf = pdf.groupby(("X", "A")).agg({("X", "B"): ["min", "max"], ("Y", "C"): "sum"})
stats_kdf = kdf.groupby((10, "A")).agg({(10, "B"): ["min", "max"], (20, "C"): "sum"})
stats_pdf = pdf.groupby((10, "A")).agg({(10, "B"): ["min", "max"], (20, "C"): "sum"})
self.assert_eq(
stats_kdf.sort_values(
by=[("X", "B", "min"), ("X", "B", "max"), ("Y", "C", "sum")]
by=[(10, "B", "min"), (10, "B", "max"), (20, "C", "sum")]
).reset_index(drop=True),
stats_pdf.sort_values(
by=[("X", "B", "min"), ("X", "B", "max"), ("Y", "C", "sum")]
by=[(10, "B", "min"), (10, "B", "max"), (20, "C", "sum")]
).reset_index(drop=True),
)

# non-string names
pdf.columns = [10, 20, 30]
kdf.columns = [10, 20, 30]

for as_index in [True, False]:
stats_kdf = kdf.groupby(10, as_index=as_index).agg({20: "min", 30: "sum"})
stats_pdf = pdf.groupby(10, as_index=as_index).agg({20: "min", 30: "sum"})
self.assert_eq(
stats_kdf.sort_values(by=[20, 30]).reset_index(drop=True),
stats_pdf.sort_values(by=[20, 30]).reset_index(drop=True),
)

stats_kdf = kdf.groupby(10).agg({20: ["min", "max"], 30: "sum"})
stats_pdf = pdf.groupby(10).agg({20: ["min", "max"], 30: "sum"})
self.assert_eq(
stats_kdf.sort_values(by=[(20, "min"), (20, "max"), (30, "sum")]).reset_index(
drop=True
),
stats_pdf.sort_values(by=[(20, "min"), (20, "max"), (30, "sum")]).reset_index(
drop=True
),
)

def test_aggregate_func_str_list(self):
# this is test for cases where only string or list is assigned
pdf = pd.DataFrame(
Expand Down Expand Up @@ -2345,10 +2399,6 @@ def test_get_group(self):
self.assertRaises(
KeyError, lambda: kdf.groupby(("B", "class"))[("A", "name")].get_group("fish")
)
self.assertRaises(
KeyError,
lambda: kdf.groupby(("B", "class"))[("A", "name")].get_group(["bird", "mammal"]),
)
Comment on lines -2348 to -2351
Copy link
Contributor

@itholic itholic Oct 7, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You removed because this is duplicated test with below ?

self.assertRaises(
KeyError,
lambda: kdf.groupby([("B", "class"), ("A", "name")]).get_group(("lion", "mammal")),
Expand Down