Skip to content

Commit 4eb5e51

Browse files
committed
Merge pull request #7975 from jreback/transform_perf
PERF: perf improvements for Series.transform (revised) (GH6496)
2 parents 914b0f3 + fe55b89 commit 4eb5e51

File tree

4 files changed

+62
-7
lines changed

4 files changed

+62
-7
lines changed

doc/source/v0.15.0.txt

+2-1
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,7 @@ previously results in ``Exception`` or ``TypeError`` (:issue:`7812`)
162162
didx
163163
didx.tz_localize(None)
164164

165-
- ``DataFrame.tz_localize`` and ``DataFrame.tz_convert`` now accept an optional ``level`` argument
165+
- ``DataFrame.tz_localize`` and ``DataFrame.tz_convert`` now accept an optional ``level`` argument
166166
for localizing a specific level of a MultiIndex (:issue:`7846`)
167167

168168
.. _whatsnew_0150.refactoring:
@@ -302,6 +302,7 @@ Performance
302302

303303
- Performance improvements in ``DatetimeIndex.__iter__`` to allow faster iteration (:issue:`7683`)
304304
- Performance improvements in ``Period`` creation (and ``PeriodIndex`` setitem) (:issue:`5155`)
305+
- Improvements in ``Series.transform`` for significant performance gains (revised) (:issue:`6496`)
305306

306307

307308

pandas/core/groupby.py

+34-4
Original file line numberDiff line numberDiff line change
@@ -2270,14 +2270,21 @@ def transform(self, func, *args, **kwargs):
22702270
-------
22712271
transformed : Series
22722272
"""
2273-
dtype = self._selected_obj.dtype
22742273

2274+
# if string function
22752275
if isinstance(func, compat.string_types):
2276-
wrapper = lambda x: getattr(x, func)(*args, **kwargs)
2277-
else:
2278-
wrapper = lambda x: func(x, *args, **kwargs)
2276+
return self._transform_fast(lambda : getattr(self, func)(*args, **kwargs))
2277+
2278+
# do we have a cython function
2279+
cyfunc = _intercept_cython(func)
2280+
if cyfunc and not args and not kwargs:
2281+
return self._transform_fast(cyfunc)
22792282

2283+
# reg transform
2284+
dtype = self._selected_obj.dtype
22802285
result = self._selected_obj.values.copy()
2286+
2287+
wrapper = lambda x: func(x, *args, **kwargs)
22812288
for i, (name, group) in enumerate(self):
22822289

22832290
object.__setattr__(group, 'name', name)
@@ -2302,6 +2309,29 @@ def transform(self, func, *args, **kwargs):
23022309
index=self._selected_obj.index,
23032310
name=self._selected_obj.name)
23042311

2312+
def _transform_fast(self, func):
    """
    Fast version of transform, only applicable to builtin/cythonizable
    functions.

    Parameters
    ----------
    func : string or callable
        Either the name of a method on this groupby object, or a callable
        taking no arguments. Calling it must return an object whose
        ``.values`` holds one aggregated value per group.

    Returns
    -------
    transformed : Series
        Each group's aggregated value repeated once per row of that group,
        carrying the index of ``self.obj`` and the name of the selected
        object.
    """
    if isinstance(func, compat.string_types):
        func = getattr(self, func)
    values = func().values

    # repeat by the number of rows in each group; ``count()`` skips missing
    # values, so using it would leave the repeated array shorter than the
    # original object whenever a group contains NaN
    counts = self.size().values
    values = np.repeat(values, counts)

    # keep the name, as the regular (non-fast) transform path does
    name = self._selected_obj.name

    # shortcut if we have an already ordered grouper: the repeated values
    # are then already in row order
    if Index(self.grouper.group_info[0]).is_monotonic:
        return Series(values, index=self.obj.index, name=name)

    # otherwise the repeated values are laid out group by group; label each
    # one with the row position it belongs to, sort back into row order and
    # then restore the original index
    indices = self.indices
    positions = Index(np.concatenate([indices[v]
                                      for v in self.grouper.result_index]))
    result = Series(values, index=positions, name=name).sort_index()
    result.index = self.obj.index

    return result
2334+
23052335
def filter(self, func, dropna=True, *args, **kwargs):
23062336
"""
23072337
Return a copy of a Series excluding elements from groups that

pandas/tests/test_groupby.py

+17-1
Original file line numberDiff line numberDiff line change
@@ -795,6 +795,20 @@ def test_transform(self):
795795
transformed = grouped.transform(lambda x: x * x.sum())
796796
self.assertEqual(transformed[7], 12)
797797

798+
def test_transform_fast(self):
    # The fast transform path must broadcast each group's aggregate back
    # onto every row of that group, for both a callable and a string name.

    # floor division keeps the ids integral, so every id labels three
    # consecutive rows; with true division each id would be a distinct
    # float and every group would hold a single row
    df = DataFrame({'id': np.arange(100000) // 3,
                    'val': np.random.randn(100000)})

    grp = df.groupby('id')['val']

    # one mean per group, repeated once for each row of the group
    expected = pd.Series(np.repeat(grp.mean().values, grp.size().values),
                         index=df.index)

    result = grp.transform(np.mean)
    assert_series_equal(result, expected)

    result = grp.transform('mean')
    assert_series_equal(result, expected)
811+
798812
def test_transform_broadcast(self):
799813
grouped = self.ts.groupby(lambda x: x.month)
800814
result = grouped.transform(np.mean)
@@ -858,12 +872,14 @@ def test_transform_select_columns(self):
858872
assert_frame_equal(result, expected)
859873

860874
def test_transform_exclude_nuisance(self):
875+
876+
# this also tests orderings in transform between
877+
# series/frame to make sure its consistent
861878
expected = {}
862879
grouped = self.df.groupby('A')
863880
expected['C'] = grouped['C'].transform(np.mean)
864881
expected['D'] = grouped['D'].transform(np.mean)
865882
expected = DataFrame(expected)
866-
867883
result = self.df.groupby('A').transform(np.mean)
868884

869885
assert_frame_equal(result, expected)

vb_suite/groupby.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -444,5 +444,13 @@ def f(g):
444444
445445
df = DataFrame({ 'signal' : np.random.rand(N)})
446446
"""
447-
448447
groupby_transform_series = Benchmark("df['signal'].groupby(g).transform(np.mean)", setup)
448+
449+
# Benchmark: transform of a single selected column on a frame grouped by an
# integer key column. Floor division keeps the ids integral, so every id
# labels three consecutive rows whether or not true division is in effect.
setup = common_setup + """
np.random.seed(0)

df=DataFrame( { 'id' : np.arange( 100000 ) // 3,
                'val': np.random.randn( 100000) } )
"""

groupby_transform_series2 = Benchmark("df.groupby('id')['val'].transform(np.mean)", setup)

0 commit comments

Comments
 (0)