Skip to content

Added cast blacklist for certain transform agg funcs #19355

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 23, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/source/whatsnew/v0.23.0.txt
Original file line number Diff line number Diff line change
Expand Up @@ -506,6 +506,7 @@ Groupby/Resample/Rolling
- Fixed regression in :func:`DataFrame.groupby` which would not emit an error when called with a tuple key not in the index (:issue:`18798`)
- Bug in :func:`DataFrame.resample` which silently ignored unsupported (or mistyped) options for ``label``, ``closed`` and ``convention`` (:issue:`19303`)
- Bug in :func:`DataFrame.groupby` where tuples were interpreted as lists of keys rather than as keys (:issue:`17979`, :issue:`18249`)
- Bug in ``transform`` where particular aggregation functions were being incorrectly cast to match the dtype(s) of the grouped data (:issue:`19200`)
-

Sparse
Expand Down
29 changes: 23 additions & 6 deletions pandas/core/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,8 @@
_cython_transforms = frozenset(['cumprod', 'cumsum', 'shift',
'cummin', 'cummax'])

_cython_cast_blacklist = frozenset(['rank', 'count', 'size'])


class Grouper(object):
"""
Expand Down Expand Up @@ -965,6 +967,21 @@ def _try_cast(self, result, obj, numeric_only=False):

return result

def _transform_should_cast(self, func_nm):
"""
Parameters:
-----------
func_nm: str
The name of the aggregation function being performed

Returns:
--------
bool
Whether transform should attempt to cast the result of aggregation
"""
return (self.size().fillna(0) > 0).any() and (func_nm not in
_cython_cast_blacklist)

def _cython_transform(self, how, numeric_only=True):
output = collections.OrderedDict()
for name, obj in self._iterate_slices():
Expand Down Expand Up @@ -3333,7 +3350,7 @@ def transform(self, func, *args, **kwargs):
else:
# cythonized aggregation and merge
return self._transform_fast(
lambda: getattr(self, func)(*args, **kwargs))
lambda: getattr(self, func)(*args, **kwargs), func)

# reg transform
klass = self._selected_obj.__class__
Expand Down Expand Up @@ -3364,7 +3381,7 @@ def transform(self, func, *args, **kwargs):
result.index = self._selected_obj.index
return result

def _transform_fast(self, func):
def _transform_fast(self, func, func_nm):
"""
fast version of transform, only applicable to
builtin/cythonizable functions
Expand All @@ -3373,7 +3390,7 @@ def _transform_fast(self, func):
func = getattr(self, func)

ids, _, ngroup = self.grouper.group_info
cast = (self.size().fillna(0) > 0).any()
cast = self._transform_should_cast(func_nm)
out = algorithms.take_1d(func().values, ids)
if cast:
out = self._try_cast(out, self.obj)
Expand Down Expand Up @@ -4127,15 +4144,15 @@ def transform(self, func, *args, **kwargs):
if not result.columns.equals(obj.columns):
return self._transform_general(func, *args, **kwargs)

return self._transform_fast(result, obj)
return self._transform_fast(result, obj, func)

def _transform_fast(self, result, obj):
def _transform_fast(self, result, obj, func_nm):
"""
Fast transform path for aggregations
"""
# if there were groups with no observations (Categorical only?)
# try casting data to original dtype
cast = (self.size().fillna(0) > 0).any()
cast = self._transform_should_cast(func_nm)

# for each col, reshape to to size of original frame
# by take operation
Expand Down
25 changes: 25 additions & 0 deletions pandas/tests/groupby/test_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -582,3 +582,28 @@ def test_transform_with_non_scalar_group(self):
'group.*',
df.groupby(axis=1, level=1).transform,
lambda z: z.div(z.sum(axis=1), axis=0))

@pytest.mark.parametrize('cols,exp,comp_func', [
('a', pd.Series([1, 1, 1], name='a'), tm.assert_series_equal),
(['a', 'c'], pd.DataFrame({'a': [1, 1, 1], 'c': [1, 1, 1]}),
tm.assert_frame_equal)
])
@pytest.mark.parametrize('agg_func', [
'count', 'rank', 'size'])
def test_transform_numeric_ret(self, cols, exp, comp_func, agg_func):
if agg_func == 'size' and isinstance(cols, list):
pytest.xfail("'size' transformation not supported with "
"NDFrameGroupy")

# GH 19200
df = pd.DataFrame(
{'a': pd.date_range('2018-01-01', periods=3),
'b': range(3),
'c': range(7, 10)})

result = df.groupby('b')[cols].transform(agg_func)

if agg_func == 'rank':
exp = exp.astype('float')

comp_func(result, exp)