diff --git a/pandas/core/groupby/generic.py b/pandas/core/groupby/generic.py
index 6376dbefcf435..3e578b4878312 100644
--- a/pandas/core/groupby/generic.py
+++ b/pandas/core/groupby/generic.py
@@ -10,7 +10,7 @@
 from functools import partial
 from textwrap import dedent
 import typing
-from typing import Any, Callable, FrozenSet, Iterable, Sequence, Type, Union, cast
+from typing import Any, Callable, Dict, FrozenSet, Iterable, Sequence, Type, Union, cast
 import warnings
 
 import numpy as np
@@ -322,7 +322,9 @@ def _aggregate_multiple_funcs(self, arg, _level):
 
         return DataFrame(results, columns=columns)
 
-    def _wrap_series_output(self, output, index, names=None):
+    def _wrap_series_output(
+        self, output: dict, index: Index, names=None
+    ) -> Union[Series, DataFrame]:
         """ common agg/transform wrapping logic """
         output = output[self._selection_name]
 
@@ -334,18 +336,22 @@ def _wrap_series_output(self, output, index, names=None):
             name = self._selected_obj.name
             return Series(output, index=index, name=name)
 
-    def _wrap_aggregated_output(self, output, names=None):
+    def _wrap_aggregated_output(
+        self, output: dict, names=None
+    ) -> Union[Series, DataFrame]:
         result = self._wrap_series_output(
             output=output, index=self.grouper.result_index, names=names
         )
         return self._reindex_output(result)._convert(datetime=True)
 
-    def _wrap_transformed_output(self, output, names=None):
+    def _wrap_transformed_output(
+        self, output: dict, names=None
+    ) -> Union[Series, DataFrame]:
         return self._wrap_series_output(
             output=output, index=self.obj.index, names=names
         )
 
-    def _wrap_applied_output(self, keys, values, not_indexed_same=False):
+    def _wrap_applied_output(self, keys, values, not_indexed_same: bool = False):
         if len(keys) == 0:
             # GH #6265
             return Series([], name=self._selection_name, index=keys)
@@ -377,8 +383,8 @@ def _get_index() -> Index:
         result = Series(data=values, index=_get_index(), name=self._selection_name)
         return self._reindex_output(result)
 
-    def _aggregate_named(self, func, *args, **kwargs):
-        result = OrderedDict()
+    def _aggregate_named(self, func, *args, **kwargs) -> Dict:
+        result: Dict = {}
 
         for name, group in self:
             group.name = name
@@ -443,18 +449,16 @@ def transform(self, func, *args, **kwargs):
         result.index = self._selected_obj.index
         return result
 
-    def _transform_fast(self, func, func_nm) -> Series:
+    def _transform_fast(self, func: Callable, func_nm: str) -> Series:
         """
         fast version of transform, only applicable to
         builtin/cythonizable functions
         """
-        if isinstance(func, str):
-            func = getattr(self, func)
-
         ids, _, ngroup = self.grouper.group_info
-        cast = self._transform_should_cast(func_nm)
+        should_cast = self._transform_should_cast(func_nm)
         out = algorithms.take_1d(func()._values, ids)
-        if cast:
+        if should_cast:
             out = self._try_cast(out, self.obj)
         return Series(out, index=self.obj.index, name=self.obj.name)
@@ -579,7 +583,12 @@ def describe(self, **kwargs):
         return result.unstack()
 
     def value_counts(
-        self, normalize=False, sort=True, ascending=False, bins=None, dropna=True
+        self,
+        normalize: bool = False,
+        sort: bool = True,
+        ascending: bool = False,
+        bins=None,
+        dropna: bool = True,
     ):
 
         from pandas.core.reshape.tile import cut
@@ -1069,6 +1078,7 @@ def _aggregate_frame(self, func, *args, **kwargs) -> DataFrame:
 
     def _aggregate_item_by_item(self, func, *args, **kwargs) -> DataFrame:
         # only for axis==0
+        should_cast = self._transform_should_cast(func)
         obj = self._obj_with_exclusions
         result = OrderedDict()  # type: dict
         cannot_agg = []
@@ -1077,9 +1087,8 @@ def _aggregate_item_by_item(self, func, *args, **kwargs) -> DataFrame:
             data = obj[item]
             colg = SeriesGroupBy(data, selection=item, grouper=self.grouper)
 
-            cast = self._transform_should_cast(func)
             try:
-                result[item] = colg.aggregate(func, *args, **kwargs)
+                res = colg.aggregate(func, *args, **kwargs)
             except ValueError as err:
                 if "Must produce aggregated value" in str(err):
@@ -1091,8 +1100,9 @@ def _aggregate_item_by_item(self, func, *args, **kwargs) -> DataFrame:
                     continue
             else:
-                if cast:
-                    result[item] = self._try_cast(result[item], data)
+                if should_cast:
+                    res = self._try_cast(res, data)
+                result[item] = res
 
         result_columns = obj.columns
         if cannot_agg:
@@ -1115,7 +1125,7 @@ def _decide_output_index(self, output, labels):
 
         return output_keys
 
-    def _wrap_applied_output(self, keys, values, not_indexed_same=False):
+    def _wrap_applied_output(self, keys, values, not_indexed_same: bool = False):
         if len(keys) == 0:
             return DataFrame(index=keys)
 
@@ -1367,13 +1377,15 @@ def transform(self, func, *args, **kwargs):
 
         return self._transform_fast(result, obj, func)
 
-    def _transform_fast(self, result: DataFrame, obj: DataFrame, func_nm) -> DataFrame:
+    def _transform_fast(
+        self, result: DataFrame, obj: DataFrame, func_nm: str
+    ) -> DataFrame:
         """
         Fast transform path for aggregations
         """
         # if there were groups with no observations (Categorical only?)
         # try casting data to original dtype
-        cast = self._transform_should_cast(func_nm)
+        should_cast = self._transform_should_cast(func_nm)
 
         # for each col, reshape to to size of original frame
         # by take operation
@@ -1381,7 +1393,7 @@ def _transform_fast(self, result: DataFrame, obj: DataFrame, func_nm) -> DataFra
         output = []
         for i, _ in enumerate(result.columns):
             res = algorithms.take_1d(result.iloc[:, i].values, ids)
-            if cast:
+            if should_cast:
                 res = self._try_cast(res, obj.iloc[:, i])
             output.append(res)
 
@@ -1579,7 +1591,7 @@ def _insert_inaxis_grouper_inplace(self, result):
             if in_axis:
                 result.insert(0, name, lev)
 
-    def _wrap_aggregated_output(self, output, names=None):
+    def _wrap_aggregated_output(self, output: dict, names=None):
         agg_axis = 0 if self.axis == 1 else 1
         agg_labels = self._obj_with_exclusions._get_axis(agg_axis)
 
@@ -1598,7 +1610,7 @@ def _wrap_aggregated_output(self, output, names=None):
 
         return self._reindex_output(result)._convert(datetime=True)
 
-    def _wrap_transformed_output(self, output, names=None) -> DataFrame:
+    def _wrap_transformed_output(self, output: dict, names=None) -> DataFrame:
         return DataFrame(output, index=self.obj.index)
 
     def _wrap_agged_blocks(self, items, blocks):
diff --git a/pandas/core/groupby/groupby.py b/pandas/core/groupby/groupby.py
index 294cb723eee1a..7f31143b3016b 100644
--- a/pandas/core/groupby/groupby.py
+++ b/pandas/core/groupby/groupby.py
@@ -832,6 +832,7 @@ def _transform_should_cast(self, func_nm: str) -> bool:
         )
 
     def _cython_transform(self, how: str, numeric_only: bool = True, **kwargs):
+        should_cast = self._transform_should_cast(how)
         output = collections.OrderedDict()  # type: dict
         for obj in self._iterate_slices():
             name = obj.name
@@ -843,20 +844,19 @@ def _cython_transform(self, how: str, numeric_only: bool = True, **kwargs):
                 result, names = self.grouper.transform(obj.values, how, **kwargs)
             except NotImplementedError:
                 continue
-            if self._transform_should_cast(how):
-                output[name] = self._try_cast(result, obj)
-            else:
-                output[name] = result
+            if should_cast:
+                result = self._try_cast(result, obj)
+            output[name] = result
 
         if len(output) == 0:
             raise DataError("No numeric types to aggregate")
 
         return self._wrap_transformed_output(output, names)
 
-    def _wrap_aggregated_output(self, output, names=None):
+    def _wrap_aggregated_output(self, output: dict, names=None):
         raise AbstractMethodError(self)
 
-    def _wrap_transformed_output(self, output, names=None):
+    def _wrap_transformed_output(self, output: dict, names=None):
         raise AbstractMethodError(self)
 
     def _wrap_applied_output(self, keys, values, not_indexed_same: bool = False):