     _MissingPandasLikeSeriesGroupBy
 from databricks.koalas.series import Series, _col
 from databricks.koalas.config import get_option
+from databricks.koalas.utils import column_index_level, scol_for
 
 
 class GroupBy(object):
@@ -113,12 +114,12 @@ def aggregate(self, func_or_funcs, *args, **kwargs):
 
         """
         if not isinstance(func_or_funcs, dict) or \
-            not all(isinstance(key, str) and
+            not all(isinstance(key, (str, tuple)) and
                     (isinstance(value, str) or
                      isinstance(value, list) and all(isinstance(v, str) for v in value))
                     for key, value in func_or_funcs.items()):
-            raise ValueError("aggs must be a dict mapping from column name (string) to aggregate "
-                             "functions (string or list of strings).")
+            raise ValueError("aggs must be a dict mapping from column name (string or tuple) to "
+                             "aggregate functions (string or list of strings).")
 
         kdf = DataFrame(GroupBy._spark_groupby(self._kdf, func_or_funcs, self._groupkeys))
         if not self._as_index:
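
A usage sketch of the tuple-key aggregation enabled above, assuming a koalas
DataFrame with two-level MultiIndex columns (the labels and data here are
hypothetical, not from the patch):

    import pandas as pd
    import databricks.koalas as ks

    pdf = pd.DataFrame({('x', 'a'): [1, 1, 2],
                        ('x', 'b'): [3, 4, 5],
                        ('y', 'c'): [6, 7, 8]})
    kdf = ks.from_pandas(pdf)

    # Dict keys may now be tuples addressing MultiIndex columns.
    kdf.groupby(kdf[('x', 'a')]).agg({('x', 'b'): 'sum', ('y', 'c'): ['min', 'max']})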
@@ -137,27 +138,28 @@ def _spark_groupby(kdf, func, groupkeys):
         data_columns = []
         column_index = []
         for key, value in func.items():
+            idx = key if isinstance(key, tuple) else (key,)
             for aggfunc in [value] if isinstance(value, str) else value:
-                data_col = "('{0}', '{1}')".format(key, aggfunc) if multi_aggs else key
+                name = kdf._internal.column_name_for(idx)
+                data_col = "('{0}', '{1}')".format(name, aggfunc) if multi_aggs else name
                 data_columns.append(data_col)
-                column_index.append((key, aggfunc))
+                column_index.append(tuple(list(idx) + [aggfunc]) if multi_aggs else idx)
                 if aggfunc == "nunique":
-                    reordered.append(F.expr('count(DISTINCT `{0}`) as `{1}`'.format(key, data_col)))
+                    reordered.append(
+                        F.expr('count(DISTINCT `{0}`) as `{1}`'.format(name, data_col)))
                 else:
-                    reordered.append(F.expr('{1}(`{0}`) as `{2}`'.format(key, aggfunc, data_col)))
+                    reordered.append(F.expr('{1}(`{0}`) as `{2}`'.format(name, aggfunc, data_col)))
         sdf = sdf.groupby(*groupkey_cols).agg(*reordered)
         if len(groupkeys) > 0:
             index_map = [('__index_level_{}__'.format(i),
                           s._internal.column_index[0])
                          for i, s in enumerate(groupkeys)]
-            return _InternalFrame(sdf=sdf,
-                                  data_columns=data_columns,
-                                  column_index=column_index if multi_aggs else None,
-                                  index_map=index_map)
         else:
-            return _InternalFrame(sdf=sdf,
-                                  data_columns=data_columns,
-                                  column_index=column_index if multi_aggs else None)
+            index_map = None
+        return _InternalFrame(sdf=sdf,
+                              data_columns=data_columns,
+                              column_index=column_index,
+                              index_map=index_map)
 
     def count(self):
         """
@@ -637,7 +639,10 @@ def cumprod(scol):
         # `SeriesGroupBy.cumprod`, `SeriesGroupBy._cum` and `Series._cum`
         #
         # This is a bit hacky. Maybe we should fix it.
-        @pandas_udf(returnType=self._ks._kdf._internal.spark_type_for(self._ks.name))
+
+        return_type = self._ks._internal.spark_type_for(self._ks._internal.column_index[0])
+
+        @pandas_udf(returnType=return_type)
         def negative_check(s):
             assert len(s) == 0 or ((s > 0) | (s.isnull())).all(), \
                 "values should be bigger than 0: %s" % s
@@ -885,6 +890,7 @@ def _spark_group_map_apply(self, func, return_schema, retain_index):
         index_columns = self._kdf._internal.index_columns
         index_names = self._kdf._internal.index_names
         data_columns = self._kdf._internal.data_columns
+        column_index = self._kdf._internal.column_index
 
         def rename_output(pdf):
             # TODO: This logic below was borrowed from `DataFrame.pandas_df` to set the index
@@ -899,11 +905,14 @@ def rename_output(pdf):
                     append = True
                 pdf = pdf[data_columns]
 
+            if column_index_level(column_index) > 1:
+                pdf.columns = pd.MultiIndex.from_tuples(column_index)
+            else:
+                pdf.columns = [None if idx is None else idx[0] for idx in column_index]
+
             if len(index_names) > 0:
-                if isinstance(pdf.index, pd.MultiIndex):
-                    pdf.index.names = index_names
-                else:
-                    pdf.index.name = index_names[0]
+                pdf.index.names = [name if name is None or len(name) > 1 else name[0]
+                                   for name in index_names]
 
             pdf = func(pdf)
@@ -1069,17 +1078,23 @@ def idxmax(self, skipna=True):
 
         stat_exprs = []
         for ks in self._agg_columns:
+            name = ks._internal.data_columns[0]
+
             if skipna:
                 order_column = Column(ks._scol._jc.desc_nulls_last())
             else:
                 order_column = Column(ks._scol._jc.desc_nulls_first())
             window = Window.partitionBy(groupkey_cols).orderBy(order_column)
-            sdf = sdf.withColumn(ks.name, F.when(F.row_number().over(window) == 1, F.col(index))
+            sdf = sdf.withColumn(name,
+                                 F.when(F.row_number().over(window) == 1, scol_for(sdf, index))
                                  .otherwise(None))
-            stat_exprs.append(F.max(F.col(ks.name)).alias(ks.name))
+            stat_exprs.append(F.max(scol_for(sdf, name)).alias(name))
         sdf = sdf.groupby(*groupkey_cols).agg(*stat_exprs)
         internal = _InternalFrame(sdf=sdf,
-                                  data_columns=[ks.name for ks in self._agg_columns],
+                                  data_columns=[ks._internal.data_columns[0]
+                                                for ks in self._agg_columns],
+                                  column_index=[ks._internal.column_index[0]
+                                                for ks in self._agg_columns],
                                   index_map=[('__index_level_{}__'.format(i),
                                               s._internal.column_index[0])
                                              for i, s in enumerate(groupkeys)])
@@ -1133,17 +1148,23 @@ def idxmin(self, skipna=True):
 
         stat_exprs = []
         for ks in self._agg_columns:
+            name = ks._internal.data_columns[0]
+
             if skipna:
                 order_column = Column(ks._scol._jc.asc_nulls_last())
             else:
                 order_column = Column(ks._scol._jc.asc_nulls_first())
             window = Window.partitionBy(groupkey_cols).orderBy(order_column)
-            sdf = sdf.withColumn(ks.name, F.when(F.row_number().over(window) == 1, F.col(index))
+            sdf = sdf.withColumn(name,
+                                 F.when(F.row_number().over(window) == 1, scol_for(sdf, index))
                                  .otherwise(None))
-            stat_exprs.append(F.max(F.col(ks.name)).alias(ks.name))
+            stat_exprs.append(F.max(scol_for(sdf, name)).alias(name))
         sdf = sdf.groupby(*groupkey_cols).agg(*stat_exprs)
         internal = _InternalFrame(sdf=sdf,
-                                  data_columns=[ks.name for ks in self._agg_columns],
+                                  data_columns=[ks._internal.data_columns[0]
+                                                for ks in self._agg_columns],
+                                  column_index=[ks._internal.column_index[0]
+                                                for ks in self._agg_columns],
                                   index_map=[('__index_level_{}__'.format(i),
                                               s._internal.column_index[0])
                                              for i, s in enumerate(groupkeys)])
@@ -1666,29 +1687,37 @@ def _reduce_for_stat_function(self, sfun, only_numeric):
         sdf = self._kdf._sdf
 
         data_columns = []
+        column_index = []
         if len(self._agg_columns) > 0:
             stat_exprs = []
             for ks in self._agg_columns:
                 spark_type = ks.spark_type
+                name = ks._internal.data_columns[0]
+                idx = ks._internal.column_index[0]
                 # TODO: we should have a function that takes dataframes and converts the numeric
                 # types. Converting the NaNs is used in a few places, it should be in utils.
                 # Special handle floating point types because Spark's count treats nan as a valid
                 # value, whereas Pandas count doesn't include nan.
                 if isinstance(spark_type, DoubleType) or isinstance(spark_type, FloatType):
-                    stat_exprs.append(sfun(F.nanvl(ks._scol, F.lit(None))).alias(ks.name))
-                    data_columns.append(ks.name)
+                    stat_exprs.append(sfun(F.nanvl(ks._scol, F.lit(None))).alias(name))
+                    data_columns.append(name)
+                    column_index.append(idx)
                 elif isinstance(spark_type, NumericType) or not only_numeric:
-                    stat_exprs.append(sfun(ks._scol).alias(ks.name))
-                    data_columns.append(ks.name)
+                    stat_exprs.append(sfun(ks._scol).alias(name))
+                    data_columns.append(name)
+                    column_index.append(idx)
             sdf = sdf.groupby(*groupkey_cols).agg(*stat_exprs)
         else:
             sdf = sdf.select(*groupkey_cols).distinct()
         sdf = sdf.sort(*groupkey_cols)
+
         internal = _InternalFrame(sdf=sdf,
-                                  data_columns=data_columns,
                                   index_map=[('__index_level_{}__'.format(i),
                                               s._internal.column_index[0])
-                                             for i, s in enumerate(groupkeys)])
+                                             for i, s in enumerate(groupkeys)],
+                                  data_columns=data_columns,
+                                  column_index=column_index,
+                                  column_index_names=self._kdf._internal.column_index_names)
         kdf = DataFrame(internal)
         if not self._as_index:
             kdf = kdf.reset_index()
@@ -1708,7 +1737,7 @@ def __init__(self, kdf: DataFrame, by: List[Series], as_index: bool = True,
             agg_columns = [idx for idx in self._kdf._internal.column_index
                            if all(not self._kdf[idx]._equals(key) for key in self._groupkeys)]
             self._have_agg_columns = False
-        self._agg_columns = [kdf[col] for col in agg_columns]
+        self._agg_columns = [kdf[idx] for idx in agg_columns]
 
     def __getattr__(self, item: str) -> Any:
         if hasattr(_MissingPandasLikeDataFrameGroupBy, item):
@@ -1725,11 +1754,13 @@ def __getitem__(self, item):
         else:
             if isinstance(item, str):
                 item = [item]
+            item = [i if isinstance(i, tuple) else (i,) for i in item]
             if not self._as_index:
                 groupkey_names = set(key.name for key in self._groupkeys)
                 for i in item:
-                    if i in groupkey_names:
-                        raise ValueError("cannot insert {}, already exists".format(i))
+                    name = str(i) if len(i) > 1 else i[0]
+                    if name in groupkey_names:
+                        raise ValueError("cannot insert {}, already exists".format(name))
             return DataFrameGroupBy(self._kdf, self._groupkeys, as_index=self._as_index,
                                     agg_columns=item)
 
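A usage sketch of the item normalization above, against the hypothetical kdf
from the first example: a list of tuples selects MultiIndex columns on the
grouped frame, while plain labels are promoted to 1-tuples internally.

    kdf.groupby(kdf[('x', 'a')])[[('x', 'b')]].sum()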