Skip to content

Commit 17a1ed9

Browse files
feat: Can call agg with some callables (#2055)
1 parent 36ee4d1 commit 17a1ed9

File tree

7 files changed

+106
-70
lines changed

7 files changed

+106
-70
lines changed

bigframes/core/groupby/dataframe_group_by.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -461,23 +461,19 @@ def expanding(self, min_periods: int = 1) -> windows.Window:
461461

462462
def agg(self, func=None, **kwargs) -> typing.Union[df.DataFrame, series.Series]:
463463
if func:
464-
if isinstance(func, str):
465-
return self.size() if func == "size" else self._agg_string(func)
466-
elif utils.is_dict_like(func):
464+
if utils.is_dict_like(func):
467465
return self._agg_dict(func)
468466
elif utils.is_list_like(func):
469467
return self._agg_list(func)
470468
else:
471-
raise NotImplementedError(
472-
f"Aggregate with {func} not supported. {constants.FEEDBACK_LINK}"
473-
)
469+
return self.size() if func == "size" else self._agg_func(func)
474470
else:
475471
return self._agg_named(**kwargs)
476472

477-
def _agg_string(self, func: str) -> df.DataFrame:
473+
def _agg_func(self, func) -> df.DataFrame:
478474
ids, labels = self._aggregated_columns()
479475
aggregations = [
480-
aggs.agg(col_id, agg_ops.lookup_agg_func(func)) for col_id in ids
476+
aggs.agg(col_id, agg_ops.lookup_agg_func(func)[0]) for col_id in ids
481477
]
482478
agg_block, _ = self._block.aggregate(
483479
by_column_ids=self._by_col_ids,
@@ -500,7 +496,7 @@ def _agg_dict(self, func: typing.Mapping) -> df.DataFrame:
500496
funcs_for_id if utils.is_list_like(funcs_for_id) else [funcs_for_id]
501497
)
502498
for f in func_list:
503-
aggregations.append(aggs.agg(col_id, agg_ops.lookup_agg_func(f)))
499+
aggregations.append(aggs.agg(col_id, agg_ops.lookup_agg_func(f)[0]))
504500
column_labels.append(label)
505501
agg_block, _ = self._block.aggregate(
506502
by_column_ids=self._by_col_ids,
@@ -525,19 +521,23 @@ def _agg_dict(self, func: typing.Mapping) -> df.DataFrame:
525521
def _agg_list(self, func: typing.Sequence) -> df.DataFrame:
526522
ids, labels = self._aggregated_columns()
527523
aggregations = [
528-
aggs.agg(col_id, agg_ops.lookup_agg_func(f)) for col_id in ids for f in func
524+
aggs.agg(col_id, agg_ops.lookup_agg_func(f)[0])
525+
for col_id in ids
526+
for f in func
529527
]
530528

531529
if self._block.column_labels.nlevels > 1:
532530
# Restructure MultiIndex for proper format: (idx1, idx2, func)
533531
# rather than ((idx1, idx2), func).
534532
column_labels = [
535-
tuple(label) + (f,)
533+
tuple(label) + (agg_ops.lookup_agg_func(f)[1],)
536534
for label in labels.to_frame(index=False).to_numpy()
537535
for f in func
538536
]
539537
else: # Single-level index
540-
column_labels = [(label, f) for label in labels for f in func]
538+
column_labels = [
539+
(label, agg_ops.lookup_agg_func(f)[1]) for label in labels for f in func
540+
]
541541

542542
agg_block, _ = self._block.aggregate(
543543
by_column_ids=self._by_col_ids,
@@ -563,7 +563,7 @@ def _agg_named(self, **kwargs) -> df.DataFrame:
563563
if not isinstance(v, tuple) or (len(v) != 2):
564564
raise TypeError("kwargs values must be 2-tuples of column, aggfunc")
565565
col_id = self._resolve_label(v[0])
566-
aggregations.append(aggs.agg(col_id, agg_ops.lookup_agg_func(v[1])))
566+
aggregations.append(aggs.agg(col_id, agg_ops.lookup_agg_func(v[1])[0]))
567567
column_labels.append(k)
568568
agg_block, _ = self._block.aggregate(
569569
by_column_ids=self._by_col_ids,

bigframes/core/groupby/series_group_by.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -216,18 +216,17 @@ def prod(self, *args) -> series.Series:
216216

217217
def agg(self, func=None) -> typing.Union[df.DataFrame, series.Series]:
218218
column_names: list[str] = []
219-
if isinstance(func, str):
220-
aggregations = [aggs.agg(self._value_column, agg_ops.lookup_agg_func(func))]
221-
column_names = [func]
222-
elif utils.is_list_like(func):
223-
aggregations = [
224-
aggs.agg(self._value_column, agg_ops.lookup_agg_func(f)) for f in func
225-
]
226-
column_names = list(func)
227-
else:
219+
if utils.is_dict_like(func):
228220
raise NotImplementedError(
229221
f"Aggregate with {func} not supported. {constants.FEEDBACK_LINK}"
230222
)
223+
if not utils.is_list_like(func):
224+
func = [func]
225+
226+
aggregations = [
227+
aggs.agg(self._value_column, agg_ops.lookup_agg_func(f)[0]) for f in func
228+
]
229+
column_names = [agg_ops.lookup_agg_func(f)[1] for f in func]
231230

232231
agg_block, _ = self._block.aggregate(
233232
by_column_ids=self._by_col_ids,

bigframes/dataframe.py

Lines changed: 10 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3172,12 +3172,7 @@ def nunique(self) -> bigframes.series.Series:
31723172
block = self._block.aggregate_all_and_stack(agg_ops.nunique_op)
31733173
return bigframes.series.Series(block)
31743174

3175-
def agg(
3176-
self,
3177-
func: str
3178-
| typing.Sequence[str]
3179-
| typing.Mapping[blocks.Label, typing.Sequence[str] | str],
3180-
) -> DataFrame | bigframes.series.Series:
3175+
def agg(self, func) -> DataFrame | bigframes.series.Series:
31813176
if utils.is_dict_like(func):
31823177
# Must check dict-like first because dictionaries are list-like
31833178
# according to Pandas.
@@ -3191,15 +3186,17 @@ def agg(
31913186
if col_id is None:
31923187
raise KeyError(f"Column {col_label} does not exist")
31933188
for agg_func in agg_func_list:
3194-
agg_op = agg_ops.lookup_agg_func(typing.cast(str, agg_func))
3189+
op_and_label = agg_ops.lookup_agg_func(agg_func)
31953190
agg_expr = (
3196-
agg_expressions.UnaryAggregation(agg_op, ex.deref(col_id))
3197-
if isinstance(agg_op, agg_ops.UnaryAggregateOp)
3198-
else agg_expressions.NullaryAggregation(agg_op)
3191+
agg_expressions.UnaryAggregation(
3192+
op_and_label[0], ex.deref(col_id)
3193+
)
3194+
if isinstance(op_and_label[0], agg_ops.UnaryAggregateOp)
3195+
else agg_expressions.NullaryAggregation(op_and_label[0])
31993196
)
32003197
aggs.append(agg_expr)
32013198
labels.append(col_label)
3202-
funcnames.append(agg_func)
3199+
funcnames.append(op_and_label[1])
32033200

32043201
# if any list in dict values, format output differently
32053202
if any(utils.is_list_like(v) for v in func.values()):
@@ -3220,7 +3217,7 @@ def agg(
32203217
)
32213218
)
32223219
elif utils.is_list_like(func):
3223-
aggregations = [agg_ops.lookup_agg_func(f) for f in func]
3220+
aggregations = [agg_ops.lookup_agg_func(f)[0] for f in func]
32243221

32253222
for dtype, agg in itertools.product(self.dtypes, aggregations):
32263223
agg.output_type(
@@ -3236,9 +3233,7 @@ def agg(
32363233

32373234
else: # function name string
32383235
return bigframes.series.Series(
3239-
self._block.aggregate_all_and_stack(
3240-
agg_ops.lookup_agg_func(typing.cast(str, func))
3241-
)
3236+
self._block.aggregate_all_and_stack(agg_ops.lookup_agg_func(func)[0])
32423237
)
32433238

32443239
aggregate = agg

bigframes/operations/aggregations.py

Lines changed: 29 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,9 @@
1717
import abc
1818
import dataclasses
1919
import typing
20-
from typing import ClassVar, Iterable, Optional, TYPE_CHECKING
20+
from typing import Callable, ClassVar, Iterable, Optional, TYPE_CHECKING
2121

22+
import numpy as np
2223
import pandas as pd
2324
import pyarrow as pa
2425

@@ -678,7 +679,7 @@ def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionT
678679

679680

680681
# TODO: Alternative names and lookup from numpy function objects
681-
_AGGREGATIONS_LOOKUP: typing.Dict[
682+
_STRING_TO_AGG_OP: typing.Dict[
682683
str, typing.Union[UnaryAggregateOp, NullaryAggregateOp]
683684
] = {
684685
op.name: op
@@ -705,17 +706,32 @@ def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionT
705706
]
706707
}
707708

709+
_CALLABLE_TO_AGG_OP: typing.Dict[
710+
Callable, typing.Union[UnaryAggregateOp, NullaryAggregateOp]
711+
] = {
712+
np.sum: sum_op,
713+
np.mean: mean_op,
714+
np.median: median_op,
715+
np.prod: product_op,
716+
np.max: max_op,
717+
np.min: min_op,
718+
np.std: std_op,
719+
np.var: var_op,
720+
np.all: all_op,
721+
np.any: any_op,
722+
np.unique: nunique_op,
723+
# TODO(b/443252872): Solve
724+
# list: ArrayAggOp(),
725+
np.size: size_op,
726+
}
708727

709-
def lookup_agg_func(key: str) -> typing.Union[UnaryAggregateOp, NullaryAggregateOp]:
710-
if callable(key):
711-
raise NotImplementedError(
712-
"Aggregating with callable object not supported, pass method name as string instead (eg. 'sum' instead of np.sum)."
713-
)
714-
if not isinstance(key, str):
715-
raise ValueError(
716-
f"Cannot aggregate using object of type: {type(key)}. Use string method name (eg. 'sum')"
717-
)
718-
if key in _AGGREGATIONS_LOOKUP:
719-
return _AGGREGATIONS_LOOKUP[key]
728+
729+
def lookup_agg_func(
730+
key,
731+
) -> tuple[typing.Union[UnaryAggregateOp, NullaryAggregateOp], str]:
732+
if key in _STRING_TO_AGG_OP:
733+
return (_STRING_TO_AGG_OP[key], key)
734+
if key in _CALLABLE_TO_AGG_OP:
735+
return (_CALLABLE_TO_AGG_OP[key], key.__name__)
720736
else:
721737
raise ValueError(f"Unrecognize aggregate function: {key}")

bigframes/series.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1330,17 +1330,15 @@ def agg(self, func: str | typing.Sequence[str]) -> scalars.Scalar | Series:
13301330
raise NotImplementedError(
13311331
f"Multiple aggregations only supported on numeric series. {constants.FEEDBACK_LINK}"
13321332
)
1333-
aggregations = [agg_ops.lookup_agg_func(f) for f in func]
1333+
aggregations = [agg_ops.lookup_agg_func(f)[0] for f in func]
13341334
return Series(
13351335
self._block.summarize(
13361336
[self._value_column],
13371337
aggregations,
13381338
)
13391339
)
13401340
else:
1341-
return self._apply_aggregation(
1342-
agg_ops.lookup_agg_func(typing.cast(str, func))
1343-
)
1341+
return self._apply_aggregation(agg_ops.lookup_agg_func(func)[0])
13441342

13451343
aggregate = agg
13461344
aggregate.__doc__ = inspect.getdoc(vendored_pandas_series.Series.agg)

tests/system/small/test_dataframe.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6016,7 +6016,7 @@ def test_astype_invalid_type_fail(scalars_dfs):
60166016
bf_df.astype(123)
60176017

60186018

6019-
def test_agg_with_dict_lists(scalars_dfs):
6019+
def test_agg_with_dict_lists_strings(scalars_dfs):
60206020
bf_df, pd_df = scalars_dfs
60216021
agg_funcs = {
60226022
"int64_too": ["min", "max"],
@@ -6031,6 +6031,21 @@ def test_agg_with_dict_lists(scalars_dfs):
60316031
)
60326032

60336033

6034+
def test_agg_with_dict_lists_callables(scalars_dfs):
6035+
bf_df, pd_df = scalars_dfs
6036+
agg_funcs = {
6037+
"int64_too": [np.min, np.max],
6038+
"int64_col": [np.min, np.var],
6039+
}
6040+
6041+
bf_result = bf_df.agg(agg_funcs).to_pandas()
6042+
pd_result = pd_df.agg(agg_funcs)
6043+
6044+
pd.testing.assert_frame_equal(
6045+
bf_result, pd_result, check_dtype=False, check_index_type=False
6046+
)
6047+
6048+
60346049
def test_agg_with_dict_list_and_str(scalars_dfs):
60356050
bf_df, pd_df = scalars_dfs
60366051
agg_funcs = {

0 commit comments

Comments
 (0)