Skip to content

Commit 5d9e37d

Browse files
authored
REF: Move groupby.agg to apply (#39311)
1 parent cd49372 commit 5d9e37d

File tree

4 files changed

+56
-5
lines changed

4 files changed

+56
-5
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ repos:
135135
files: ^pandas/tests/
136136
- id: FrameOrSeriesUnion
137137
name: Check for use of Union[Series, DataFrame] instead of FrameOrSeriesUnion alias
138-
entry: Union\[.*(Series.*DataFrame|DataFrame.*Series).*\]
138+
entry: Union\[.*(Series,.*DataFrame|DataFrame,.*Series).*\]
139139
language: pygrep
140140
types: [python]
141141
exclude: ^pandas/_typing\.py$

pandas/core/apply.py

Lines changed: 40 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,18 @@
22

33
import abc
44
import inspect
5-
from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional, Tuple, Type, cast
5+
from typing import (
6+
TYPE_CHECKING,
7+
Any,
8+
Dict,
9+
Iterator,
10+
List,
11+
Optional,
12+
Tuple,
13+
Type,
14+
Union,
15+
cast,
16+
)
617

718
import numpy as np
819

@@ -13,6 +24,7 @@
1324
AggFuncType,
1425
AggFuncTypeBase,
1526
AggFuncTypeDict,
27+
AggObjType,
1628
Axis,
1729
FrameOrSeriesUnion,
1830
)
@@ -34,6 +46,7 @@
3446

3547
if TYPE_CHECKING:
3648
from pandas import DataFrame, Index, Series
49+
from pandas.core.groupby import DataFrameGroupBy, SeriesGroupBy
3750

3851
ResType = Dict[int, Any]
3952

@@ -86,7 +99,7 @@ class Apply(metaclass=abc.ABCMeta):
8699

87100
def __init__(
88101
self,
89-
obj: FrameOrSeriesUnion,
102+
obj: AggObjType,
90103
func,
91104
raw: bool,
92105
result_type: Optional[str],
@@ -646,3 +659,28 @@ def apply_standard(self) -> FrameOrSeriesUnion:
646659
return obj._constructor(mapped, index=obj.index).__finalize__(
647660
obj, method="apply"
648661
)
662+
663+
664+
class GroupByApply(Apply):
665+
obj: Union[SeriesGroupBy, DataFrameGroupBy]
666+
667+
def __init__(
668+
self,
669+
obj: Union[SeriesGroupBy, DataFrameGroupBy],
670+
func: AggFuncType,
671+
args,
672+
kwds,
673+
):
674+
kwds = kwds.copy()
675+
self.axis = obj.obj._get_axis_number(kwds.get("axis", 0))
676+
super().__init__(
677+
obj,
678+
func,
679+
raw=False,
680+
result_type=None,
681+
args=args,
682+
kwds=kwds,
683+
)
684+
685+
def apply(self):
686+
raise NotImplementedError

pandas/core/groupby/generic.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,11 +57,11 @@
5757
from pandas.core import algorithms, nanops
5858
from pandas.core.aggregation import (
5959
agg_list_like,
60-
aggregate,
6160
maybe_mangle_lambdas,
6261
reconstruct_func,
6362
validate_func_kwargs,
6463
)
64+
from pandas.core.apply import GroupByApply
6565
from pandas.core.arrays import Categorical, ExtensionArray
6666
from pandas.core.base import DataError, SpecificationError
6767
import pandas.core.common as com
@@ -952,7 +952,8 @@ def aggregate(self, func=None, *args, engine=None, engine_kwargs=None, **kwargs)
952952
relabeling, func, columns, order = reconstruct_func(func, **kwargs)
953953
func = maybe_mangle_lambdas(func)
954954

955-
result, how = aggregate(self, func, *args, **kwargs)
955+
op = GroupByApply(self, func, args, kwargs)
956+
result, how = op.agg()
956957
if how is None:
957958
return result
958959

pandas/tests/groupby/aggregate/test_aggregate.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,18 @@ def test_aggregate_str_func(tsframe, groupbyfunc):
203203
tm.assert_frame_equal(result, expected)
204204

205205

206+
def test_agg_str_with_kwarg_axis_1_raises(df, reduction_func):
207+
gb = df.groupby(level=0)
208+
if reduction_func in ("idxmax", "idxmin"):
209+
error = TypeError
210+
msg = "reduction operation '.*' not allowed for this dtype"
211+
else:
212+
error = ValueError
213+
msg = f"Operation {reduction_func} does not support axis=1"
214+
with pytest.raises(error, match=msg):
215+
gb.agg(reduction_func, axis=1)
216+
217+
206218
def test_aggregate_item_by_item(df):
207219
grouped = df.groupby("A")
208220

0 commit comments

Comments
 (0)