From 3f8b7f3e458543d2e65fcfd7e64d07ee4c2afb7b Mon Sep 17 00:00:00 2001 From: Will Ayd Date: Mon, 22 Jan 2018 17:14:35 -0800 Subject: [PATCH] Added cast blacklist for certain transform agg funcs --- doc/source/whatsnew/v0.23.0.txt | 1 + pandas/core/groupby.py | 29 ++++++++++++++++++++------ pandas/tests/groupby/test_transform.py | 25 ++++++++++++++++++++++ 3 files changed, 49 insertions(+), 6 deletions(-) diff --git a/doc/source/whatsnew/v0.23.0.txt b/doc/source/whatsnew/v0.23.0.txt index 7f697003f44b9..71492154419fb 100644 --- a/doc/source/whatsnew/v0.23.0.txt +++ b/doc/source/whatsnew/v0.23.0.txt @@ -506,6 +506,7 @@ Groupby/Resample/Rolling - Fixed regression in :func:`DataFrame.groupby` which would not emit an error when called with a tuple key not in the index (:issue:`18798`) - Bug in :func:`DataFrame.resample` which silently ignored unsupported (or mistyped) options for ``label``, ``closed`` and ``convention`` (:issue:`19303`) - Bug in :func:`DataFrame.groupby` where tuples were interpreted as lists of keys rather than as keys (:issue:`17979`, :issue:`18249`) +- Bug in ``transform`` where particular aggregation functions were being incorrectly cast to match the dtype(s) of the grouped data (:issue:`19200`) - Sparse diff --git a/pandas/core/groupby.py b/pandas/core/groupby.py index fc7a0faef0cf6..2c1deb9db7bba 100644 --- a/pandas/core/groupby.py +++ b/pandas/core/groupby.py @@ -345,6 +345,8 @@ _cython_transforms = frozenset(['cumprod', 'cumsum', 'shift', 'cummin', 'cummax']) +_cython_cast_blacklist = frozenset(['rank', 'count', 'size']) + class Grouper(object): """ @@ -965,6 +967,21 @@ def _try_cast(self, result, obj, numeric_only=False): return result + def _transform_should_cast(self, func_nm): + """ + Parameters: + ----------- + func_nm: str + The name of the aggregation function being performed + + Returns: + -------- + bool + Whether transform should attempt to cast the result of aggregation + """ + return (self.size().fillna(0) > 0).any() and (func_nm not in + _cython_cast_blacklist) + def _cython_transform(self, how, numeric_only=True): output = collections.OrderedDict() for name, obj in self._iterate_slices(): @@ -3333,7 +3350,7 @@ def transform(self, func, *args, **kwargs): else: # cythonized aggregation and merge return self._transform_fast( - lambda: getattr(self, func)(*args, **kwargs)) + lambda: getattr(self, func)(*args, **kwargs), func) # reg transform klass = self._selected_obj.__class__ @@ -3364,7 +3381,7 @@ def transform(self, func, *args, **kwargs): result.index = self._selected_obj.index return result - def _transform_fast(self, func): + def _transform_fast(self, func, func_nm): """ fast version of transform, only applicable to builtin/cythonizable functions @@ -3373,7 +3390,7 @@ def _transform_fast(self, func): func = getattr(self, func) ids, _, ngroup = self.grouper.group_info - cast = (self.size().fillna(0) > 0).any() + cast = self._transform_should_cast(func_nm) out = algorithms.take_1d(func().values, ids) if cast: out = self._try_cast(out, self.obj) @@ -4127,15 +4144,15 @@ def transform(self, func, *args, **kwargs): if not result.columns.equals(obj.columns): return self._transform_general(func, *args, **kwargs) - return self._transform_fast(result, obj) + return self._transform_fast(result, obj, func) - def _transform_fast(self, result, obj): + def _transform_fast(self, result, obj, func_nm): """ Fast transform path for aggregations """ # if there were groups with no observations (Categorical only?) # try casting data to original dtype - cast = (self.size().fillna(0) > 0).any() + cast = self._transform_should_cast(func_nm) # for each col, reshape to to size of original frame # by take operation diff --git a/pandas/tests/groupby/test_transform.py b/pandas/tests/groupby/test_transform.py index 8f72da293a50c..4159d0f709a13 100644 --- a/pandas/tests/groupby/test_transform.py +++ b/pandas/tests/groupby/test_transform.py @@ -582,3 +582,28 @@ def test_transform_with_non_scalar_group(self): 'group.*', df.groupby(axis=1, level=1).transform, lambda z: z.div(z.sum(axis=1), axis=0)) + + @pytest.mark.parametrize('cols,exp,comp_func', [ + ('a', pd.Series([1, 1, 1], name='a'), tm.assert_series_equal), + (['a', 'c'], pd.DataFrame({'a': [1, 1, 1], 'c': [1, 1, 1]}), + tm.assert_frame_equal) + ]) + @pytest.mark.parametrize('agg_func', [ + 'count', 'rank', 'size']) + def test_transform_numeric_ret(self, cols, exp, comp_func, agg_func): + if agg_func == 'size' and isinstance(cols, list): + pytest.xfail("'size' transformation not supported with " + "NDFrameGroupy") + + # GH 19200 + df = pd.DataFrame( + {'a': pd.date_range('2018-01-01', periods=3), + 'b': range(3), + 'c': range(7, 10)}) + + result = df.groupby('b')[cols].transform(agg_func) + + if agg_func == 'rank': + exp = exp.astype('float') + + comp_func(result, exp)