2626from functools import partial , reduce
2727import sys
2828from itertools import zip_longest
29- from typing import Any , Optional , List , Tuple , Union , Generic , TypeVar , Iterable
29+ from typing import Any , Optional , List , Tuple , Union , Generic , TypeVar , Iterable , Dict
3030
3131import numpy as np
3232import pandas as pd
4646from pyspark .sql .functions import pandas_udf , PandasUDFType
4747
4848from databricks import koalas as ks # For running doctests and reference resolution in PyCharm.
49- from databricks .koalas .config import get_option
5049from databricks .koalas .utils import validate_arguments_and_invoke_function , align_diff_frames
5150from databricks .koalas .generic import _Frame
5251from databricks .koalas .internal import _InternalFrame , IndexMap
@@ -874,7 +873,7 @@ def applymap(self, func):
874873
875874 # TODO: Series support is not implemented yet.
876875 # TODO: not all arguments are implemented comparing to Pandas' for now.
877- def aggregate (self , func_or_funcs ):
876+ def aggregate (self , func : Union [ List [ str ], Dict [ str , List [ str ]]] ):
878877 """Aggregate using one or more operations over the specified axis.
879878
880879 Parameters
@@ -889,12 +888,6 @@ def aggregate(self, func_or_funcs):
889888 -------
890889 DataFrame
891890
892- The return can be:
893-
894- * DataFrame : when DataFrame.agg is called with several functions
895-
896- Return a DataFrame.
897-
898891 Notes
899892 -----
900893 `agg` is an alias for `aggregate`. Use the alias.
@@ -934,46 +927,48 @@ def aggregate(self, func_or_funcs):
934927 min 1.0 2.0
935928 sum 12.0 NaN
936929 """
937- if isinstance (func_or_funcs , list ):
938- func_or_funcs = dict ([
939- (column , func_or_funcs ) for column in self .columns ])
930+ from databricks .koalas .groupby import GroupBy
940931
941- if not isinstance (func_or_funcs , dict ) or \
932+ if isinstance (func , list ):
933+ if all ((isinstance (f , str ) for f in func )):
934+ func = dict ([
935+ (column , func ) for column in self .columns ])
936+ else :
937+ raise ValueError ("If the given function is a list, it "
938+ "should only contains function names as strings." )
939+
940+ if not isinstance (func , dict ) or \
942941 not all (isinstance (key , str ) and
943942 (isinstance (value , str ) or
944943 isinstance (value , list ) and all (isinstance (v , str ) for v in value ))
945- for key , value in func_or_funcs .items ()):
944+ for key , value in func .items ()):
946945 raise ValueError ("aggs must be a dict mapping from column name (string) to aggregate "
947- "functions (string or list of strings)." )
946+ "functions (list of strings)." )
948947
949- sdf = self ._sdf
950- multi_aggs = any (isinstance (v , list ) for v in func_or_funcs .values ())
951- reordered = []
952- data_columns = []
953- column_index = []
954- for key , value in func_or_funcs .items ():
955- for aggfunc in [value ] if isinstance (value , str ) else value :
956- data_col = "('{0}', '{1}')" .format (key , aggfunc ) if multi_aggs else key
957- data_columns .append (data_col )
958- column_index .append ((key , aggfunc ))
959- if aggfunc == "nunique" :
960- reordered .append (F .expr ('count(DISTINCT `{0}`) as `{1}`' .format (key , data_col )))
961- else :
962- reordered .append (F .expr ('{1}(`{0}`) as `{2}`' .format (key , aggfunc , data_col )))
963- sdf = sdf .groupby ().agg (* reordered )
964- internal = _InternalFrame (sdf = sdf ,
965- data_columns = data_columns ,
966- column_index = column_index if multi_aggs else None )
948+ kdf = DataFrame (GroupBy ._spark_groupby (self , func , ())) # type: DataFrame
967949
968- kdf = DataFrame (internal )
950+ # The code below basically converts:
951+ #
952+ # A B
953+ # sum min min max
954+ # 0 12.0 1.0 2.0 8.0
955+ #
956+ # to:
957+ # A B
958+ # max NaN 8.0
959+ # min 1.0 2.0
960+ # sum 12.0 NaN
961+ #
962+ # The aggregated output is usually quite small, so it is fine to use the pandas API directly.
969963 pdf = kdf .to_pandas ().transpose ().reset_index ()
970964 pdf = pdf .groupby (['level_1' ]).apply (
971965 lambda gpdf : gpdf .drop ('level_1' , 1 ).set_index ('level_0' ).transpose ()
972966 ).reset_index (level = 1 )
973967 pdf = pdf .drop (columns = 'level_1' )
974968 pdf .columns .names = [None ]
975969 pdf .index .names = [None ]
976- return DataFrame (pdf [func_or_funcs .keys ()])
970+
971+ return DataFrame (pdf [list (func .keys ())])
977972
978973 agg = aggregate
979974
0 commit comments