Skip to content

Commit 88c4310

Browse files
committed
Cleanup
1 parent 68e6863 commit 88c4310

File tree

2 files changed

+52
-49
lines changed

2 files changed

+52
-49
lines changed

databricks/koalas/frame.py

Lines changed: 30 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
from functools import partial, reduce
2727
import sys
2828
from itertools import zip_longest
29-
from typing import Any, Optional, List, Tuple, Union, Generic, TypeVar, Iterable
29+
from typing import Any, Optional, List, Tuple, Union, Generic, TypeVar, Iterable, Dict
3030

3131
import numpy as np
3232
import pandas as pd
@@ -46,7 +46,6 @@
4646
from pyspark.sql.functions import pandas_udf, PandasUDFType
4747

4848
from databricks import koalas as ks # For running doctests and reference resolution in PyCharm.
49-
from databricks.koalas.config import get_option
5049
from databricks.koalas.utils import validate_arguments_and_invoke_function, align_diff_frames
5150
from databricks.koalas.generic import _Frame
5251
from databricks.koalas.internal import _InternalFrame, IndexMap
@@ -874,7 +873,7 @@ def applymap(self, func):
874873

875874
# TODO: Series support is not implemented yet.
876875
# TODO: not all arguments are implemented comparing to Pandas' for now.
877-
def aggregate(self, func_or_funcs):
876+
def aggregate(self, func: Union[List[str], Dict[str, List[str]]]):
878877
"""Aggregate using one or more operations over the specified axis.
879878
880879
Parameters
@@ -889,12 +888,6 @@ def aggregate(self, func_or_funcs):
889888
-------
890889
DataFrame
891890
892-
The return can be:
893-
894-
* DataFrame : when DataFrame.agg is called with several functions
895-
896-
Return a DataFrame.
897-
898891
Notes
899892
-----
900893
`agg` is an alias for `aggregate`. Use the alias.
@@ -934,46 +927,48 @@ def aggregate(self, func_or_funcs):
934927
min 1.0 2.0
935928
sum 12.0 NaN
936929
"""
937-
if isinstance(func_or_funcs, list):
938-
func_or_funcs = dict([
939-
(column, func_or_funcs) for column in self.columns])
930+
from databricks.koalas.groupby import GroupBy
940931

941-
if not isinstance(func_or_funcs, dict) or \
932+
if isinstance(func, list):
933+
if all((isinstance(f, str) for f in func)):
934+
func = dict([
935+
(column, func) for column in self.columns])
936+
else:
937+
raise ValueError("If the given function is a list, it "
938+
"should only contain function names as strings.")
939+
940+
if not isinstance(func, dict) or \
942941
not all(isinstance(key, str) and
943942
(isinstance(value, str) or
944943
isinstance(value, list) and all(isinstance(v, str) for v in value))
945-
for key, value in func_or_funcs.items()):
944+
for key, value in func.items()):
946945
raise ValueError("aggs must be a dict mapping from column name (string) to aggregate "
947-
"functions (string or list of strings).")
946+
"functions (list of strings).")
948947

949-
sdf = self._sdf
950-
multi_aggs = any(isinstance(v, list) for v in func_or_funcs.values())
951-
reordered = []
952-
data_columns = []
953-
column_index = []
954-
for key, value in func_or_funcs.items():
955-
for aggfunc in [value] if isinstance(value, str) else value:
956-
data_col = "('{0}', '{1}')".format(key, aggfunc) if multi_aggs else key
957-
data_columns.append(data_col)
958-
column_index.append((key, aggfunc))
959-
if aggfunc == "nunique":
960-
reordered.append(F.expr('count(DISTINCT `{0}`) as `{1}`'.format(key, data_col)))
961-
else:
962-
reordered.append(F.expr('{1}(`{0}`) as `{2}`'.format(key, aggfunc, data_col)))
963-
sdf = sdf.groupby().agg(*reordered)
964-
internal = _InternalFrame(sdf=sdf,
965-
data_columns=data_columns,
966-
column_index=column_index if multi_aggs else None)
948+
kdf = DataFrame(GroupBy._spark_groupby(self, func, ())) # type: DataFrame
967949

968-
kdf = DataFrame(internal)
950+
# The codes below basically converts:
951+
#
952+
# A B
953+
# sum min min max
954+
# 0 12.0 1.0 2.0 8.0
955+
#
956+
# to:
957+
# A B
958+
# max NaN 8.0
959+
# min 1.0 2.0
960+
# sum 12.0 NaN
961+
#
962+
# Aggregated output is usually pretty much small. So it is fine to directly use pandas API.
969963
pdf = kdf.to_pandas().transpose().reset_index()
970964
pdf = pdf.groupby(['level_1']).apply(
971965
lambda gpdf: gpdf.drop('level_1', 1).set_index('level_0').transpose()
972966
).reset_index(level=1)
973967
pdf = pdf.drop(columns='level_1')
974968
pdf.columns.names = [None]
975969
pdf.index.names = [None]
976-
return DataFrame(pdf[func_or_funcs.keys()])
970+
971+
return DataFrame(pdf[list(func.keys())])
977972

978973
agg = aggregate
979974

databricks/koalas/groupby.py

Lines changed: 22 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -120,15 +120,23 @@ def aggregate(self, func_or_funcs, *args, **kwargs):
120120
raise ValueError("aggs must be a dict mapping from column name (string) to aggregate "
121121
"functions (string or list of strings).")
122122

123-
sdf = self._kdf._sdf
124-
groupkeys = self._groupkeys
123+
kdf = DataFrame(GroupBy._spark_groupby(self._kdf, func_or_funcs, self._groupkeys))
124+
if not self._as_index:
125+
kdf = kdf.reset_index()
126+
return kdf
127+
128+
agg = aggregate
129+
130+
@staticmethod
131+
def _spark_groupby(kdf, func, groupkeys):
132+
sdf = kdf._sdf
125133
groupkey_cols = [s._scol.alias('__index_level_{}__'.format(i))
126134
for i, s in enumerate(groupkeys)]
127-
multi_aggs = any(isinstance(v, list) for v in func_or_funcs.values())
135+
multi_aggs = any(isinstance(v, list) for v in func.values())
128136
reordered = []
129137
data_columns = []
130138
column_index = []
131-
for key, value in func_or_funcs.items():
139+
for key, value in func.items():
132140
for aggfunc in [value] if isinstance(value, str) else value:
133141
data_col = "('{0}', '{1}')".format(key, aggfunc) if multi_aggs else key
134142
data_columns.append(data_col)
@@ -138,18 +146,18 @@ def aggregate(self, func_or_funcs, *args, **kwargs):
138146
else:
139147
reordered.append(F.expr('{1}(`{0}`) as `{2}`'.format(key, aggfunc, data_col)))
140148
sdf = sdf.groupby(*groupkey_cols).agg(*reordered)
141-
internal = _InternalFrame(sdf=sdf,
149+
if len(groupkeys) > 0:
150+
index_map = [('__index_level_{}__'.format(i),
151+
s._internal.column_index[0])
152+
for i, s in enumerate(groupkeys)]
153+
return _InternalFrame(sdf=sdf,
142154
data_columns=data_columns,
143155
column_index=column_index if multi_aggs else None,
144-
index_map=[('__index_level_{}__'.format(i),
145-
s._internal.column_index[0])
146-
for i, s in enumerate(groupkeys)])
147-
kdf = DataFrame(internal)
148-
if not self._as_index:
149-
kdf = kdf.reset_index()
150-
return kdf
151-
152-
agg = aggregate
156+
index_map=index_map)
157+
else:
158+
return _InternalFrame(sdf=sdf,
159+
data_columns=data_columns,
160+
column_index=column_index if multi_aggs else None)
153161

154162
def count(self):
155163
"""

0 commit comments

Comments (0)