
Commit 08215a1

ueshin authored and HyukjinKwon committed
Fix groupby and its functions to support multi-index columns. (#833)
1 parent f54dba4 commit 08215a1
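As a quick illustration (not part of the commit itself), the kind of call this change is meant to support looks roughly like the sketch below; the frame, labels, and data are hypothetical.

import databricks.koalas as ks
import pandas as pd

# A frame whose columns form a two-level pandas MultiIndex.
pdf = pd.DataFrame({('x', 'a'): [1, 1, 2],
                    ('x', 'b'): [3, 4, 5],
                    ('y', 'c'): [6, 7, 8]})
kdf = ks.DataFrame(pdf)

# Group keys (and aggregation columns) can now be addressed by their full label tuples.
kdf.groupby(('x', 'a')).sum()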


3 files changed (+321 / -39 lines)


databricks/koalas/generic.py

Lines changed: 4 additions & 2 deletions
@@ -1255,11 +1255,13 @@ def groupby(self, by, as_index: bool = True):
 
         df_or_s = self
         if isinstance(by, str):
+            by = [(by,)]
+        elif isinstance(by, tuple):
             by = [by]
         elif isinstance(by, Series):
             by = [by]
         elif isinstance(by, Iterable):
-            by = list(by)
+            by = [key if isinstance(key, (tuple, Series)) else (key,) for key in by]
         else:
             raise ValueError('Not a valid index: TODO')
         if not len(by):
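The branches above normalize every accepted form of `by` into a list of column-label tuples, with Series group keys passed through unchanged. A self-contained sketch of that rule (Series branch omitted; not the library code):

def normalize_by(by):
    if isinstance(by, str):
        return [(by,)]
    if isinstance(by, tuple):
        return [by]
    return [key if isinstance(key, tuple) else (key,) for key in by]

assert normalize_by('a') == [('a',)]
assert normalize_by(('x', 'a')) == [('x', 'a')]
assert normalize_by(['a', ('x', 'b')]) == [('a',), ('x', 'b')]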
@@ -1421,7 +1423,7 @@ def _resolve_col(kdf, col_like):
         assert kdf is col_like._kdf, \
             "Cannot combine column argument because it comes from a different dataframe"
         return col_like
-    elif isinstance(col_like, str):
+    elif isinstance(col_like, tuple):
         return kdf[col_like]
     else:
         raise ValueError(col_like)
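Since `groupby` now passes tuples rather than strings, `_resolve_col` falls back to tuple-based column selection, which for MultiIndex columns picks out a single column. The same selection semantics in plain pandas, with illustrative labels:

import pandas as pd

pdf = pd.DataFrame({('x', 'a'): [1, 2], ('x', 'b'): [3, 4]})
pdf[('x', 'a')]   # a single Series labelled ('x', 'a')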

databricks/koalas/groupby.py

Lines changed: 65 additions & 34 deletions
@@ -39,6 +39,7 @@
     _MissingPandasLikeSeriesGroupBy
 from databricks.koalas.series import Series, _col
 from databricks.koalas.config import get_option
+from databricks.koalas.utils import column_index_level, scol_for
 
 
 class GroupBy(object):
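The two helpers imported here are used throughout the rest of the diff. Rough stand-ins for them, assuming their obvious behaviour (the real implementations live in databricks.koalas.utils and may differ in detail):

def scol_for(sdf, column_name):
    # the Spark Column behind a data column name
    return sdf[column_name]

def column_index_level(column_index):
    # number of levels in the column labels, assuming they are consistent
    return max((len(idx) for idx in column_index if idx is not None), default=0)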
@@ -113,12 +114,12 @@ def aggregate(self, func_or_funcs, *args, **kwargs):
 
         """
         if not isinstance(func_or_funcs, dict) or \
-                not all(isinstance(key, str) and
+                not all(isinstance(key, (str, tuple)) and
                         (isinstance(value, str) or
                          isinstance(value, list) and all(isinstance(v, str) for v in value))
                         for key, value in func_or_funcs.items()):
-            raise ValueError("aggs must be a dict mapping from column name (string) to aggregate "
-                             "functions (string or list of strings).")
+            raise ValueError("aggs must be a dict mapping from column name (string or tuple) to "
+                             "aggregate functions (string or list of strings).")
 
         kdf = DataFrame(GroupBy._spark_groupby(self._kdf, func_or_funcs, self._groupkeys))
         if not self._as_index:
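With the key check widened to accept tuples, an aggregation spec can name columns by their full label tuples when the frame has MultiIndex columns; when several functions are requested per column, `_spark_groupby` below appends the function name as an extra label level (for example ('x', 'b') with 'min' and 'max' yields ('x', 'b', 'min') and ('x', 'b', 'max')). A minimal, hypothetical usage sketch:

import databricks.koalas as ks
import pandas as pd

kdf = ks.DataFrame(pd.DataFrame({('x', 'a'): [1, 1, 2],
                                 ('x', 'b'): [3, 4, 5]}))
kdf.groupby(('x', 'a')).agg({('x', 'b'): ['min', 'max']})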
@@ -137,27 +138,28 @@ def _spark_groupby(kdf, func, groupkeys):
         data_columns = []
         column_index = []
         for key, value in func.items():
+            idx = key if isinstance(key, tuple) else (key,)
             for aggfunc in [value] if isinstance(value, str) else value:
-                data_col = "('{0}', '{1}')".format(key, aggfunc) if multi_aggs else key
+                name = kdf._internal.column_name_for(idx)
+                data_col = "('{0}', '{1}')".format(name, aggfunc) if multi_aggs else name
                 data_columns.append(data_col)
-                column_index.append((key, aggfunc))
+                column_index.append(tuple(list(idx) + [aggfunc]) if multi_aggs else idx)
                 if aggfunc == "nunique":
-                    reordered.append(F.expr('count(DISTINCT `{0}`) as `{1}`'.format(key, data_col)))
+                    reordered.append(
+                        F.expr('count(DISTINCT `{0}`) as `{1}`'.format(name, data_col)))
                 else:
-                    reordered.append(F.expr('{1}(`{0}`) as `{2}`'.format(key, aggfunc, data_col)))
+                    reordered.append(F.expr('{1}(`{0}`) as `{2}`'.format(name, aggfunc, data_col)))
         sdf = sdf.groupby(*groupkey_cols).agg(*reordered)
         if len(groupkeys) > 0:
             index_map = [('__index_level_{}__'.format(i),
                           s._internal.column_index[0])
                          for i, s in enumerate(groupkeys)]
-            return _InternalFrame(sdf=sdf,
-                                  data_columns=data_columns,
-                                  column_index=column_index if multi_aggs else None,
-                                  index_map=index_map)
         else:
-            return _InternalFrame(sdf=sdf,
-                                  data_columns=data_columns,
-                                  column_index=column_index if multi_aggs else None)
+            index_map = None
+        return _InternalFrame(sdf=sdf,
+                              data_columns=data_columns,
+                              column_index=column_index,
+                              index_map=index_map)
 
     def count(self):
         """
@@ -637,7 +639,10 @@ def cumprod(scol):
             # `SeriesGroupBy.cumprod`, `SeriesGroupBy._cum` and `Series._cum`
             #
             # This is a bit hacky. Maybe we should fix it.
-            @pandas_udf(returnType=self._ks._kdf._internal.spark_type_for(self._ks.name))
+
+            return_type = self._ks._internal.spark_type_for(self._ks._internal.column_index[0])
+
+            @pandas_udf(returnType=return_type)
             def negative_check(s):
                 assert len(s) == 0 or ((s > 0) | (s.isnull())).all(), \
                     "values should be bigger than 0: %s" % s
@@ -885,6 +890,7 @@ def _spark_group_map_apply(self, func, return_schema, retain_index):
         index_columns = self._kdf._internal.index_columns
         index_names = self._kdf._internal.index_names
         data_columns = self._kdf._internal.data_columns
+        column_index = self._kdf._internal.column_index
 
         def rename_output(pdf):
             # TODO: This logic below was borrowed from `DataFrame.pandas_df` to set the index
@@ -899,11 +905,14 @@ def rename_output(pdf):
                     append = True
             pdf = pdf[data_columns]
 
+            if column_index_level(column_index) > 1:
+                pdf.columns = pd.MultiIndex.from_tuples(column_index)
+            else:
+                pdf.columns = [None if idx is None else idx[0] for idx in column_index]
+
             if len(index_names) > 0:
-                if isinstance(pdf.index, pd.MultiIndex):
-                    pdf.index.names = index_names
-                else:
-                    pdf.index.name = index_names[0]
+                pdf.index.names = [name if name is None or len(name) > 1 else name[0]
+                                   for name in index_names]
 
             pdf = func(pdf)
 
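`rename_output` now restores the pandas column labels from `column_index` before handing the batch to the user function: multi-level labels become a pd.MultiIndex, single-level ones collapse back to plain names. A pure-pandas sketch with illustrative labels, where the max(...) expression stands in for `column_index_level`:

import pandas as pd

column_index = [('x', 'a'), ('x', 'b')]
pdf = pd.DataFrame([[1, 2], [3, 4]], columns=['c0', 'c1'])

if max(len(idx) for idx in column_index) > 1:
    pdf.columns = pd.MultiIndex.from_tuples(column_index)
else:
    pdf.columns = [idx[0] for idx in column_index]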
@@ -1069,17 +1078,23 @@ def idxmax(self, skipna=True):
 
         stat_exprs = []
         for ks in self._agg_columns:
+            name = ks._internal.data_columns[0]
+
             if skipna:
                 order_column = Column(ks._scol._jc.desc_nulls_last())
             else:
                 order_column = Column(ks._scol._jc.desc_nulls_first())
             window = Window.partitionBy(groupkey_cols).orderBy(order_column)
-            sdf = sdf.withColumn(ks.name, F.when(F.row_number().over(window) == 1, F.col(index))
+            sdf = sdf.withColumn(name,
+                                 F.when(F.row_number().over(window) == 1, scol_for(sdf, index))
                                  .otherwise(None))
-            stat_exprs.append(F.max(F.col(ks.name)).alias(ks.name))
+            stat_exprs.append(F.max(scol_for(sdf, name)).alias(name))
         sdf = sdf.groupby(*groupkey_cols).agg(*stat_exprs)
         internal = _InternalFrame(sdf=sdf,
-                                  data_columns=[ks.name for ks in self._agg_columns],
+                                  data_columns=[ks._internal.data_columns[0]
+                                                for ks in self._agg_columns],
+                                  column_index=[ks._internal.column_index[0]
+                                                for ks in self._agg_columns],
                                   index_map=[('__index_level_{}__'.format(i),
                                               s._internal.column_index[0])
                                              for i, s in enumerate(groupkeys)])
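Because the rebuilt _InternalFrame now carries `column_index` for every aggregation column, `idxmax` (and `idxmin` below) keep MultiIndex column labels in their results. A hypothetical usage sketch:

import databricks.koalas as ks
import pandas as pd

kdf = ks.DataFrame(pd.DataFrame({('x', 'a'): [1, 1, 2],
                                 ('x', 'b'): [3, 5, 4]}))
kdf.groupby(('x', 'a')).idxmax()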
@@ -1133,17 +1148,23 @@ def idxmin(self, skipna=True):
 
         stat_exprs = []
         for ks in self._agg_columns:
+            name = ks._internal.data_columns[0]
+
             if skipna:
                 order_column = Column(ks._scol._jc.asc_nulls_last())
             else:
                 order_column = Column(ks._scol._jc.asc_nulls_first())
             window = Window.partitionBy(groupkey_cols).orderBy(order_column)
-            sdf = sdf.withColumn(ks.name, F.when(F.row_number().over(window) == 1, F.col(index))
+            sdf = sdf.withColumn(name,
+                                 F.when(F.row_number().over(window) == 1, scol_for(sdf, index))
                                  .otherwise(None))
-            stat_exprs.append(F.max(F.col(ks.name)).alias(ks.name))
+            stat_exprs.append(F.max(scol_for(sdf, name)).alias(name))
         sdf = sdf.groupby(*groupkey_cols).agg(*stat_exprs)
         internal = _InternalFrame(sdf=sdf,
-                                  data_columns=[ks.name for ks in self._agg_columns],
+                                  data_columns=[ks._internal.data_columns[0]
+                                                for ks in self._agg_columns],
+                                  column_index=[ks._internal.column_index[0]
+                                                for ks in self._agg_columns],
                                   index_map=[('__index_level_{}__'.format(i),
                                               s._internal.column_index[0])
                                              for i, s in enumerate(groupkeys)])
@@ -1666,29 +1687,37 @@ def _reduce_for_stat_function(self, sfun, only_numeric):
         sdf = self._kdf._sdf
 
         data_columns = []
+        column_index = []
         if len(self._agg_columns) > 0:
             stat_exprs = []
             for ks in self._agg_columns:
                 spark_type = ks.spark_type
+                name = ks._internal.data_columns[0]
+                idx = ks._internal.column_index[0]
                 # TODO: we should have a function that takes dataframes and converts the numeric
                 # types. Converting the NaNs is used in a few places, it should be in utils.
                 # Special handle floating point types because Spark's count treats nan as a valid
                 # value, whereas Pandas count doesn't include nan.
                 if isinstance(spark_type, DoubleType) or isinstance(spark_type, FloatType):
-                    stat_exprs.append(sfun(F.nanvl(ks._scol, F.lit(None))).alias(ks.name))
-                    data_columns.append(ks.name)
+                    stat_exprs.append(sfun(F.nanvl(ks._scol, F.lit(None))).alias(name))
+                    data_columns.append(name)
+                    column_index.append(idx)
                 elif isinstance(spark_type, NumericType) or not only_numeric:
-                    stat_exprs.append(sfun(ks._scol).alias(ks.name))
-                    data_columns.append(ks.name)
+                    stat_exprs.append(sfun(ks._scol).alias(name))
+                    data_columns.append(name)
+                    column_index.append(idx)
             sdf = sdf.groupby(*groupkey_cols).agg(*stat_exprs)
         else:
             sdf = sdf.select(*groupkey_cols).distinct()
         sdf = sdf.sort(*groupkey_cols)
+
         internal = _InternalFrame(sdf=sdf,
-                                  data_columns=data_columns,
                                   index_map=[('__index_level_{}__'.format(i),
                                               s._internal.column_index[0])
-                                             for i, s in enumerate(groupkeys)])
+                                             for i, s in enumerate(groupkeys)],
+                                  data_columns=data_columns,
+                                  column_index=column_index,
+                                  column_index_names=self._kdf._internal.column_index_names)
         kdf = DataFrame(internal)
         if not self._as_index:
             kdf = kdf.reset_index()
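The `F.nanvl(..., F.lit(None))` wrapper kept above exists because Spark counts NaN as a regular value while pandas' count skips it; converting NaN to null first keeps the two behaviours in line. A small pandas/numpy illustration of the gap being papered over:

import numpy as np
import pandas as pd

s = pd.Series([1.0, np.nan, 3.0])
assert s.count() == 2   # pandas skips NaN
# Spark's count would report 3 here unless NaN is first turned into null,
# which is what sfun(F.nanvl(ks._scol, F.lit(None))) does above.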
@@ -1708,7 +1737,7 @@ def __init__(self, kdf: DataFrame, by: List[Series], as_index: bool = True,
             agg_columns = [idx for idx in self._kdf._internal.column_index
                            if all(not self._kdf[idx]._equals(key) for key in self._groupkeys)]
             self._have_agg_columns = False
-        self._agg_columns = [kdf[col] for col in agg_columns]
+        self._agg_columns = [kdf[idx] for idx in agg_columns]
 
     def __getattr__(self, item: str) -> Any:
         if hasattr(_MissingPandasLikeDataFrameGroupBy, item):
@@ -1725,11 +1754,13 @@ def __getitem__(self, item):
         else:
             if isinstance(item, str):
                 item = [item]
+            item = [i if isinstance(i, tuple) else (i,) for i in item]
             if not self._as_index:
                 groupkey_names = set(key.name for key in self._groupkeys)
                 for i in item:
-                    if i in groupkey_names:
-                        raise ValueError("cannot insert {}, already exists".format(i))
+                    name = str(i) if len(i) > 1 else i[0]
+                    if name in groupkey_names:
+                        raise ValueError("cannot insert {}, already exists".format(name))
             return DataFrameGroupBy(self._kdf, self._groupkeys, as_index=self._as_index,
                                     agg_columns=item)
 
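`__getitem__` applies the same tuple normalization to column selections, so aggregation columns can be picked from a groupby by their full labels. A hypothetical usage sketch:

import databricks.koalas as ks
import pandas as pd

kdf = ks.DataFrame(pd.DataFrame({('x', 'a'): [1, 1, 2],
                                 ('x', 'b'): [3, 4, 5],
                                 ('y', 'c'): [6, 7, 8]}))
kdf.groupby(('x', 'a'))[[('x', 'b')]].sum()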