Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 115 additions & 1 deletion databricks/koalas/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@

from pyspark.sql import functions as F, Window
from pyspark.sql.types import FloatType, DoubleType, NumericType, StructField, StructType
from pyspark.sql.functions import PandasUDFType, pandas_udf
from pyspark.sql.functions import PandasUDFType, pandas_udf, Column

from databricks import koalas as ks # For running doctests and reference resolution in PyCharm.
from databricks.koalas.typedef import _infer_return_type, as_spark_type
Expand Down Expand Up @@ -1008,6 +1008,120 @@ def rank(self, method='average', ascending=True):
"""
return self._rank(method, ascending)

# TODO: add axis parameter
def idxmax(self, skipna=True):
    """
    Return, per group, the index label of the row holding each column's maximum.

    Parameters
    ----------
    skipna : boolean, default True
        When true, nulls are ordered last inside each group so that a non-null
        value is picked whenever the group has one. When false, nulls are
        ordered first, so a row holding a null is picked if one is present.

    Returns
    -------
    DataFrame
        One row per group, indexed by the group keys, with one column per
        aggregated column containing the selected index value.

    Raises
    ------
    ValueError
        If the underlying frame does not have exactly one index level.

    See Also
    --------
    Series.idxmax
    DataFrame.idxmax
    databricks.koalas.Series.groupby
    databricks.koalas.DataFrame.groupby

    Examples
    --------
    >>> df = ks.DataFrame({'a': [1, 1, 2, 2, 3],
    ...                    'b': [1, 2, 3, 4, 5],
    ...                    'c': [5, 4, 3, 2, 1]}, columns=['a', 'b', 'c'])

    >>> df.groupby(['a']).idxmax().sort_index() # doctest: +NORMALIZE_WHITESPACE
       b  c
    a
    1  1  0
    2  3  2
    3  4  4
    """
    # Guard first: only a single index level can be reported back.
    if len(self._kdf._internal.index_names) != 1:
        raise ValueError('idxmax only support one-level index now')

    groupkeys = self._groupkeys
    # Group keys are re-labelled positionally so they can serve as the result index.
    groupkey_cols = [key._scol.alias('__index_level_{}__'.format(pos))
                     for pos, key in enumerate(groupkeys)]
    frame = self._kdf._sdf
    index_column = self._kdf._internal.index_columns[0]

    aggregations = []
    for kser in self._agg_columns:
        jcol = kser._scol._jc
        # Descending order puts the largest value on the first row of each group;
        # `skipna` only decides on which side the nulls are placed.
        if skipna:
            ordering = Column(jcol.desc_nulls_last())
        else:
            ordering = Column(jcol.desc_nulls_first())
        per_group = Window.partitionBy(groupkey_cols).orderBy(ordering)
        is_first_row = F.row_number().over(per_group) == 1
        # Overwrite the column: the first-ranked row keeps the index value and
        # every other row becomes null.
        picked_index = F.when(is_first_row, F.col(index_column)).otherwise(None)
        frame = frame.withColumn(kser.name, picked_index)
        # Reduce the column to a single value per group in the final aggregation.
        aggregations.append(F.max(F.col(kser.name)).alias(kser.name))

    frame = frame.groupby(*groupkey_cols).agg(*aggregations)
    internal = _InternalFrame(
        sdf=frame,
        data_columns=[kser.name for kser in self._agg_columns],
        index_map=[('__index_level_{}__'.format(pos), key.name)
                   for pos, key in enumerate(groupkeys)])
    return DataFrame(internal)

# TODO: add axis parameter
def idxmin(self, skipna=True):
    """
    Return, per group, the index label of the row holding each column's minimum.

    Parameters
    ----------
    skipna : boolean, default True
        When true, nulls are ordered last inside each group so that a non-null
        value is picked whenever the group has one. When false, nulls are
        ordered first, so a row holding a null is picked if one is present.

    Returns
    -------
    DataFrame
        One row per group, indexed by the group keys, with one column per
        aggregated column containing the selected index value.

    Raises
    ------
    ValueError
        If the underlying frame does not have exactly one index level.

    See Also
    --------
    Series.idxmin
    DataFrame.idxmin
    databricks.koalas.Series.groupby
    databricks.koalas.DataFrame.groupby

    Examples
    --------
    >>> df = ks.DataFrame({'a': [1, 1, 2, 2, 3],
    ...                    'b': [1, 2, 3, 4, 5],
    ...                    'c': [5, 4, 3, 2, 1]}, columns=['a', 'b', 'c'])

    >>> df.groupby(['a']).idxmin().sort_index() # doctest: +NORMALIZE_WHITESPACE
       b  c
    a
    1  0  1
    2  2  3
    3  4  4
    """
    # Guard first: only a single index level can be reported back.
    if len(self._kdf._internal.index_names) != 1:
        raise ValueError('idxmin only support one-level index now')

    groupkeys = self._groupkeys
    # Group keys are re-labelled positionally so they can serve as the result index.
    groupkey_cols = [key._scol.alias('__index_level_{}__'.format(pos))
                     for pos, key in enumerate(groupkeys)]
    frame = self._kdf._sdf
    index_column = self._kdf._internal.index_columns[0]

    aggregations = []
    for kser in self._agg_columns:
        jcol = kser._scol._jc
        # Ascending order puts the smallest value on the first row of each group;
        # `skipna` only decides on which side the nulls are placed.
        if skipna:
            ordering = Column(jcol.asc_nulls_last())
        else:
            ordering = Column(jcol.asc_nulls_first())
        per_group = Window.partitionBy(groupkey_cols).orderBy(ordering)
        is_first_row = F.row_number().over(per_group) == 1
        # Overwrite the column: the first-ranked row keeps the index value and
        # every other row becomes null.
        picked_index = F.when(is_first_row, F.col(index_column)).otherwise(None)
        frame = frame.withColumn(kser.name, picked_index)
        # NOTE(review): F.max is used here as well, not F.min. Presumably this is
        # fine because at most one row per group is non-null at this point and
        # the aggregate ignores nulls - confirm against Spark's aggregate docs.
        aggregations.append(F.max(F.col(kser.name)).alias(kser.name))

    frame = frame.groupby(*groupkey_cols).agg(*aggregations)
    internal = _InternalFrame(
        sdf=frame,
        data_columns=[kser.name for kser in self._agg_columns],
        index_map=[('__index_level_{}__'.format(pos), key.name)
                   for pos, key in enumerate(groupkeys)])
    return DataFrame(internal)

# TODO: Series support is not implemented yet.
def transform(self, func):
"""
Expand Down
4 changes: 0 additions & 4 deletions databricks/koalas/missing/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,6 @@ class _MissingPandasLikeDataFrameGroupBy(object):
fillna = unsupported_property('fillna')
groups = unsupported_property('groups')
hist = unsupported_property('hist')
idxmax = unsupported_property('idxmax')
idxmin = unsupported_property('idxmin')
indices = unsupported_property('indices')
mad = unsupported_property('mad')
ngroups = unsupported_property('ngroups')
Expand Down Expand Up @@ -85,8 +83,6 @@ class _MissingPandasLikeSeriesGroupBy(object):
fillna = unsupported_property('fillna')
groups = unsupported_property('groups')
hist = unsupported_property('hist')
idxmax = unsupported_property('idxmax')
idxmin = unsupported_property('idxmin')
indices = unsupported_property('indices')
is_monotonic_decreasing = unsupported_property('is_monotonic_decreasing')
is_monotonic_increasing = unsupported_property('is_monotonic_increasing')
Expand Down
20 changes: 20 additions & 0 deletions databricks/koalas/tests/test_groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,26 @@ def test_filter(self):
self.assert_eq(kdf.groupby(['a', 'b']).filter(lambda x: any(x.a == 2)).sort_index(),
pdf.groupby(['a', 'b']).filter(lambda x: any(x.a == 2)).sort_index())

def test_idxmax(self):
    """Check groupby().idxmax() against pandas and the multi-level index guard."""
    data = {'a': [1, 1, 2, 2, 3],
            'b': [1, 2, 3, 4, 5],
            'c': [5, 4, 3, 2, 1]}
    pdf = pd.DataFrame(data, columns=['a', 'b', 'c'])
    kdf = koalas.DataFrame(pdf)

    # The pandas result is the reference; the koalas result is sorted by index
    # before being compared.
    expected = pdf.groupby(['a']).idxmax()
    actual = kdf.groupby(['a']).idxmax().sort_index()
    self.assert_eq(expected, actual)

    # A two-level index must be rejected with the documented message.
    multi_indexed = kdf.set_index(['a', 'b'])
    with self.assertRaisesRegex(ValueError, 'idxmax only support one-level index now'):
        multi_indexed.groupby(['c']).idxmax()

def test_idxmin(self):
    """Check groupby().idxmin() against pandas and the multi-level index guard."""
    data = {'a': [1, 1, 2, 2, 3],
            'b': [1, 2, 3, 4, 5],
            'c': [5, 4, 3, 2, 1]}
    pdf = pd.DataFrame(data, columns=['a', 'b', 'c'])
    kdf = koalas.DataFrame(pdf)

    # The pandas result is the reference; the koalas result is sorted by index
    # before being compared.
    expected = pdf.groupby(['a']).idxmin()
    actual = kdf.groupby(['a']).idxmin().sort_index()
    self.assert_eq(expected, actual)

    # A two-level index must be rejected with the documented message.
    multi_indexed = kdf.set_index(['a', 'b'])
    with self.assertRaisesRegex(ValueError, 'idxmin only support one-level index now'):
        multi_indexed.groupby(['c']).idxmin()

def test_missing(self):
kdf = koalas.DataFrame({'a': [1, 2, 3, 4, 5, 6, 7, 8, 9]})

Expand Down
2 changes: 2 additions & 0 deletions docs/source/reference/groupby.rst
Original file line number Diff line number Diff line change
Expand Up @@ -44,3 +44,5 @@ Computations / Descriptive Stats
GroupBy.var
GroupBy.size
GroupBy.diff
GroupBy.idxmax
GroupBy.idxmin