diff --git a/pandas/core/groupby/generic.py b/pandas/core/groupby/generic.py index 0bd6f746e4f3a..41a5195008f0c 100644 --- a/pandas/core/groupby/generic.py +++ b/pandas/core/groupby/generic.py @@ -132,6 +132,9 @@ def pinner(cls): class SeriesGroupBy(GroupBy): _apply_whitelist = base.series_apply_whitelist + def _iterate_slices(self): + yield self._selection_name, self._selected_obj + @property def _selection_name(self): """ @@ -323,7 +326,7 @@ def _aggregate_multiple_funcs(self, arg, _level): return DataFrame(results, columns=columns) - def _wrap_output(self, output, index, names=None): + def _wrap_series_output(self, output, index, names=None): """ common agg/transform wrapping logic """ output = output[self._selection_name] @@ -336,13 +339,15 @@ def _wrap_output(self, output, index, names=None): return Series(output, index=index, name=name) def _wrap_aggregated_output(self, output, names=None): - result = self._wrap_output( + result = self._wrap_series_output( output=output, index=self.grouper.result_index, names=names ) return self._reindex_output(result)._convert(datetime=True) def _wrap_transformed_output(self, output, names=None): - return self._wrap_output(output=output, index=self.obj.index, names=names) + return self._wrap_series_output( + output=output, index=self.obj.index, names=names + ) def _wrap_applied_output(self, keys, values, not_indexed_same=False): if len(keys) == 0: @@ -866,7 +871,7 @@ def aggregate(self, func=None, *args, **kwargs): if self.grouper.nkeys > 1: return self._python_agg_general(func, *args, **kwargs) elif args or kwargs: - result = self._aggregate_generic(func, *args, **kwargs) + result = self._aggregate_frame(func, *args, **kwargs) else: # try to treat as if we are passing a list @@ -875,7 +880,7 @@ def aggregate(self, func=None, *args, **kwargs): [func], _level=_level, _axis=self.axis ) except Exception: - result = self._aggregate_generic(func) + result = self._aggregate_frame(func) else: result.columns = Index( result.columns.levels[0], name=self._selected_obj.columns.name @@ -999,7 +1004,7 @@ def _cython_agg_blocks(self, how, alt=None, numeric_only=True, min_count=-1): return new_items, new_blocks - def _aggregate_generic(self, func, *args, **kwargs): + def _aggregate_frame(self, func, *args, **kwargs): if self.grouper.nkeys != 1: raise AssertionError("Number of keys must be 1") @@ -1022,7 +1027,7 @@ def _aggregate_generic(self, func, *args, **kwargs): wrapper = lambda x: func(x, *args, **kwargs) result[name] = data.apply(wrapper, axis=axis) - return self._wrap_generic_output(result, obj) + return self._wrap_frame_output(result, obj) def _aggregate_item_by_item(self, func, *args, **kwargs): # only for axis==0 @@ -1506,7 +1511,7 @@ def _gotitem(self, key, ndim, subset=None): raise AssertionError("invalid ndim for _gotitem") - def _wrap_generic_output(self, result, obj): + def _wrap_frame_output(self, result, obj): result_index = self.grouper.levels[0] if self.axis == 0: diff --git a/pandas/core/groupby/groupby.py b/pandas/core/groupby/groupby.py index cb56f7b8d535b..a3808d1e85e33 100644 --- a/pandas/core/groupby/groupby.py +++ b/pandas/core/groupby/groupby.py @@ -747,7 +747,7 @@ def _python_apply_general(self, f): ) def _iterate_slices(self): - yield self._selection_name, self._selected_obj + raise AbstractMethodError(self) def transform(self, func, *args, **kwargs): raise AbstractMethodError(self) @@ -872,6 +872,12 @@ def _cython_transform(self, how, numeric_only=True, **kwargs): def _wrap_aggregated_output(self, output, names=None): raise AbstractMethodError(self) + def _wrap_transformed_output(self, output, names=None): + raise AbstractMethodError(self) + + def _wrap_applied_output(self, keys, values, not_indexed_same=False): + raise AbstractMethodError(self) + def _cython_agg_general(self, how, alt=None, numeric_only=True, min_count=-1): output = {} for name, obj in self._iterate_slices(): @@ -922,9 +928,6 @@ def _python_agg_general(self, func, *args, **kwargs): return self._wrap_aggregated_output(output) - def _wrap_applied_output(self, *args, **kwargs): - raise AbstractMethodError(self) - def _concat_objects(self, keys, values, not_indexed_same=False): from pandas.core.reshape.concat import concat