Skip to content

Commit 7c46bdd

Browse files
authored
FEAT-#2375: implementation of multi-column groupby aggregation (#2461)
Signed-off-by: Dmitry Chigarev <[email protected]>
1 parent d710a16 commit 7c46bdd

File tree

12 files changed

+353
-127
lines changed

12 files changed

+353
-127
lines changed

modin/backends/base/query_compiler.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1394,6 +1394,7 @@ def groupby_size(
13941394
reduce_args=reduce_args,
13951395
numeric_only=numeric_only,
13961396
drop=drop,
1397+
method="size",
13971398
)
13981399

13991400
def groupby_agg(
@@ -1407,13 +1408,10 @@ def groupby_agg(
14071408
groupby_kwargs,
14081409
drop=False,
14091410
):
1410-
if is_multi_by:
1411-
if isinstance(by, type(self)) and len(by.columns) == 1:
1412-
by = by.columns[0] if drop else by.to_pandas().squeeze()
1413-
elif isinstance(by, type(self)):
1414-
by = list(by.columns)
1415-
else:
1416-
by = by.to_pandas().squeeze() if isinstance(by, type(self)) else by
1411+
if isinstance(by, type(self)) and len(by.columns) == 1:
1412+
by = by.columns[0] if drop else by.to_pandas().squeeze()
1413+
elif isinstance(by, type(self)):
1414+
by = list(by.columns)
14171415

14181416
return GroupByDefault.register(pandas.core.groupby.DataFrameGroupBy.aggregate)(
14191417
self,

modin/backends/pandas/query_compiler.py

Lines changed: 135 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -2525,34 +2525,74 @@ def groupby_agg(
25252525
if callable(agg_func):
25262526
agg_func = wrap_udf_function(agg_func)
25272527

2528-
if is_multi_by:
2529-
return super().groupby_agg(
2530-
by=by,
2531-
is_multi_by=is_multi_by,
2532-
axis=axis,
2533-
agg_func=agg_func,
2534-
agg_args=agg_args,
2535-
agg_kwargs=agg_kwargs,
2536-
groupby_kwargs=groupby_kwargs,
2537-
drop=drop,
2538-
)
2539-
2540-
by = by.to_pandas().squeeze() if isinstance(by, type(self)) else by
2541-
25422528
# since we're going to modify `groupby_kwargs` dict in a `groupby_agg_builder`,
25432529
# we want to copy it to not propagate these changes into source dict, in case
25442530
# of unsuccessful end of function
25452531
groupby_kwargs = groupby_kwargs.copy()
25462532

25472533
as_index = groupby_kwargs.get("as_index", True)
2534+
if isinstance(by, type(self)):
2535+
# `drop` parameter indicates whether or not 'by' data came
2536+
# from the `self` frame:
2537+
# True: 'by' data came from the `self`
2538+
# False: external 'by' data
2539+
if drop:
2540+
internal_by = by.columns
2541+
by = [by]
2542+
else:
2543+
internal_by = []
2544+
by = [by]
2545+
else:
2546+
if not isinstance(by, list):
2547+
by = [by]
2548+
internal_by = [o for o in by if isinstance(o, str) and o in self.columns]
2549+
internal_qc = (
2550+
[self.getitem_column_array(internal_by)] if len(internal_by) else []
2551+
)
2552+
2553+
by = internal_qc + by[len(internal_by) :]
2554+
2555+
broadcastable_by = [o._modin_frame for o in by if isinstance(o, type(self))]
2556+
not_broadcastable_by = [o for o in by if not isinstance(o, type(self))]
25482557

2549-
def groupby_agg_builder(df):
2558+
def groupby_agg_builder(df, by=None, drop=False, partition_idx=None):
25502559
# Set `as_index` to True to track the metadata of the grouping object
25512560
# It is used to make sure that between phases we are constructing the
25522561
# right index and placing columns in the correct order.
25532562
groupby_kwargs["as_index"] = True
25542563

2555-
def compute_groupby(df):
2564+
internal_by_cols = pandas.Index([])
2565+
missmatched_cols = pandas.Index([])
2566+
if by is not None:
2567+
internal_by_df = by[internal_by]
2568+
2569+
if isinstance(internal_by_df, pandas.Series):
2570+
internal_by_df = internal_by_df.to_frame()
2571+
2572+
missmatched_cols = internal_by_df.columns.difference(df.columns)
2573+
df = pandas.concat(
2574+
[df, internal_by_df[missmatched_cols]],
2575+
axis=1,
2576+
copy=False,
2577+
)
2578+
internal_by_cols = internal_by_df.columns
2579+
2580+
external_by = by.columns.difference(internal_by)
2581+
external_by_df = by[external_by].squeeze(axis=1)
2582+
2583+
if isinstance(external_by_df, pandas.DataFrame):
2584+
external_by_cols = [o for _, o in external_by_df.iteritems()]
2585+
else:
2586+
external_by_cols = [external_by_df]
2587+
2588+
by = internal_by_cols.tolist() + external_by_cols
2589+
2590+
else:
2591+
by = []
2592+
2593+
by += not_broadcastable_by
2594+
2595+
def compute_groupby(df, drop=False, partition_idx=0):
25562596
grouped_df = df.groupby(by=by, axis=axis, **groupby_kwargs)
25572597
try:
25582598
if isinstance(agg_func, dict):
@@ -2569,17 +2609,91 @@ def compute_groupby(df):
25692609
# issues with extracting the index.
25702610
except (DataError, TypeError):
25712611
result = pandas.DataFrame(index=grouped_df.size().index)
2612+
if isinstance(result, pandas.Series):
2613+
result = result.to_frame(
2614+
result.name if result.name is not None else "__reduced__"
2615+
)
2616+
2617+
result_cols = result.columns
2618+
result.drop(columns=missmatched_cols, inplace=True, errors="ignore")
2619+
2620+
if not as_index:
2621+
keep_index_levels = len(by) > 1 and any(
2622+
isinstance(x, pandas.CategoricalDtype)
2623+
for x in df[internal_by_cols].dtypes
2624+
)
2625+
2626+
cols_to_insert = (
2627+
internal_by_cols.intersection(result_cols)
2628+
if keep_index_levels
2629+
else internal_by_cols.difference(result_cols)
2630+
)
2631+
2632+
if keep_index_levels:
2633+
result.drop(
2634+
columns=cols_to_insert, inplace=True, errors="ignore"
2635+
)
2636+
2637+
drop = True
2638+
if partition_idx == 0:
2639+
drop = False
2640+
if not keep_index_levels:
2641+
lvls_to_drop = [
2642+
i
2643+
for i, name in enumerate(result.index.names)
2644+
if name not in cols_to_insert
2645+
]
2646+
if len(lvls_to_drop) == result.index.nlevels:
2647+
drop = True
2648+
else:
2649+
result.index = result.index.droplevel(lvls_to_drop)
2650+
2651+
if (
2652+
not isinstance(result.index, pandas.MultiIndex)
2653+
and result.index.name is None
2654+
):
2655+
drop = True
2656+
2657+
result.reset_index(drop=drop, inplace=True)
2658+
2659+
new_index_names = [
2660+
None
2661+
if isinstance(name, str) and name.startswith("__reduced__")
2662+
else name
2663+
for name in result.index.names
2664+
]
2665+
2666+
cols_to_drop = (
2667+
result.columns[result.columns.str.match(r"__reduced__.*", na=False)]
2668+
if hasattr(result.columns, "str")
2669+
else []
2670+
)
2671+
2672+
result.index.names = new_index_names
2673+
2674+
# Not dropping columns if result is Series
2675+
if len(result.columns) > 1:
2676+
result.drop(columns=cols_to_drop, inplace=True)
2677+
25722678
return result
25732679

25742680
try:
2575-
return compute_groupby(df)
2681+
return compute_groupby(df, drop, partition_idx)
25762682
# This will happen with Arrow buffer read-only errors. We don't want to copy
25772683
# all the time, so this will try to fast-path the code first.
25782684
except (ValueError, KeyError):
2579-
return compute_groupby(df.copy())
2685+
return compute_groupby(df.copy(), drop, partition_idx)
25802686

2581-
new_modin_frame = self._modin_frame._apply_full_axis(
2582-
axis, lambda df: groupby_agg_builder(df)
2687+
apply_indices = list(agg_func.keys()) if isinstance(agg_func, dict) else None
2688+
2689+
new_modin_frame = self._modin_frame.broadcast_apply_full_axis(
2690+
axis=axis,
2691+
func=lambda df, by=None, partition_idx=None: groupby_agg_builder(
2692+
df, by, drop, partition_idx
2693+
),
2694+
other=broadcastable_by,
2695+
apply_indices=apply_indices,
2696+
enumerate_partitions=True,
25832697
)
25842698
result = self.__constructor__(new_modin_frame)
25852699

@@ -2598,14 +2712,7 @@ def compute_groupby(df):
25982712
except Exception as e:
25992713
raise type(e)("No numeric types to aggregate.")
26002714

2601-
# Reset `as_index` because it was edited inplace.
2602-
groupby_kwargs["as_index"] = as_index
2603-
if as_index:
2604-
return result
2605-
else:
2606-
if result.index.name is None or result.index.name in result.columns:
2607-
drop = False
2608-
return result.reset_index(drop=not drop)
2715+
return result
26092716

26102717
# END Manual Partitioning methods
26112718

modin/data_management/functions/default_methods/groupby_default.py

Lines changed: 9 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@ class GroupBy:
2727
@classmethod
2828
def validate_by(cls, by):
2929
def try_cast_series(df):
30+
if isinstance(df, pandas.DataFrame):
31+
df = df.squeeze(axis=1)
3032
if not isinstance(df, pandas.Series):
3133
return df
3234
if df.name == "__reduced__":
@@ -73,11 +75,6 @@ def fn(
7375
):
7476
by = cls.validate_by(by)
7577

76-
if not is_multi_by:
77-
groupby_args = groupby_args.copy()
78-
as_index = groupby_args.pop("as_index", True)
79-
groupby_args["as_index"] = True
80-
8178
grp = df.groupby(by, axis=axis, **groupby_args)
8279
agg_func = cls.get_func(grp, key, **kwargs)
8380
result = (
@@ -86,15 +83,7 @@ def fn(
8683
else agg_func(grp, **agg_args)
8784
)
8885

89-
if not is_multi_by:
90-
if as_index:
91-
return result
92-
else:
93-
if result.index.name is None or result.index.name in result.columns:
94-
drop = False
95-
return result.reset_index(drop=not drop)
96-
else:
97-
return result
86+
return result
9887

9988
return fn
10089

@@ -111,6 +100,7 @@ def fn(
111100
**kwargs
112101
):
113102
if not isinstance(by, (pandas.Series, pandas.DataFrame)):
103+
by = cls.validate_by(by)
114104
return agg_func(
115105
df.groupby(by=by, axis=axis, **groupby_args), **map_args
116106
)
@@ -137,11 +127,16 @@ def fn(
137127
grp = df.groupby(by, axis=axis, **groupby_args)
138128
result = agg_func(grp, **map_args)
139129

130+
if isinstance(result, pandas.Series):
131+
result = result.to_frame()
132+
140133
if not as_index:
141134
if (
142135
len(result.index.names) == 1 and result.index.names[0] is None
143136
) or all([name in result.columns for name in result.index.names]):
144137
drop = False
138+
elif kwargs.get("method") == "size":
139+
drop = True
145140
result = result.reset_index(drop=not drop)
146141

147142
if result.index.name == "__reduced__":

modin/data_management/functions/groupby_function.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def caller(
3131
drop=False,
3232
):
3333
if not isinstance(by, (type(query_compiler), str)):
34-
by = try_cast_to_pandas(by)
34+
by = try_cast_to_pandas(by, squeeze=True)
3535
return query_compiler.default_to_pandas(
3636
lambda df: map_func(
3737
df.groupby(by=by, axis=axis, **groupby_args), **map_args

modin/engines/base/frame/axis_partition.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
from abc import ABC
1515
import pandas
16+
import numpy as np
1617
from modin.data_management.utils import split_result_of_axis_func_pandas
1718

1819

@@ -146,12 +147,13 @@ def apply(
146147
if other_axis_partition is not None:
147148
if not isinstance(other_axis_partition, list):
148149
other_axis_partition = [other_axis_partition]
149-
other_shape = (
150-
len(other_axis_partition),
151-
len(other_axis_partition[0].list_of_blocks),
150+
151+
# (other_shape[i-1], other_shape[i]) will indicate slice
152+
# to restore i-1 axis partition
153+
other_shape = np.cumsum(
154+
[0] + [len(o.list_of_blocks) for o in other_axis_partition]
152155
)
153-
if not self.axis:
154-
other_shape = tuple(reversed(other_shape))
156+
155157
return self._wrap_partitions(
156158
self.deploy_func_between_two_axis_partitions(
157159
self.axis,
@@ -268,14 +270,14 @@ def deploy_func_between_two_axis_partitions(
268270

269271
rt_parts = partitions[len_of_left:]
270272

271-
# reshaping flattened `rt_parts` array into with shape `other_shape`
273+
# reshaping flattened `rt_parts` array into a frame with shape `other_shape`
272274
combined_axis = [
273275
pandas.concat(
274-
[rt_parts[other_shape[axis] * i + j] for j in range(other_shape[axis])],
276+
rt_parts[other_shape[i - 1] : other_shape[i]],
275277
axis=axis,
276278
copy=False,
277279
)
278-
for i in range(other_shape[axis ^ 1])
280+
for i in range(1, len(other_shape))
279281
]
280282
rt_frame = pandas.concat(combined_axis, axis=axis ^ 1, copy=False)
281283

0 commit comments

Comments (0)