Skip to content

Commit 1e9c0a6

Browse files
jseaboldjreback
authored andcommitted
ENH: rolling_/expanding_apply take args, kwargs for func
DOC: Add rolling_apply/expanding_apply ENH to release notes. BUG: Use dummy args to accomodate most general case.
1 parent 31505b3 commit 1e9c0a6

File tree

4 files changed

+56
-23
lines changed

4 files changed

+56
-23
lines changed

doc/source/release.rst

+2
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,8 @@ Improvements to existing features
6565

6666
- pd.read_clipboard will, if 'sep' is unspecified, try to detect data copied from a spreadsheet
6767
and parse accordingly. (:issue:`6223`)
68+
- pd.expanding_apply and pd.rolling_apply now take args and kwargs that are passed on to
69+
the func (:issue:`6289`)
6870

6971
.. _release.bug_fixes-0.14.0:
7072

pandas/algos.pyx

+4-3
Original file line numberDiff line numberDiff line change
@@ -1627,7 +1627,7 @@ def roll_quantile(ndarray[float64_t, cast=True] input, int win,
16271627
return output
16281628

16291629
def roll_generic(ndarray[float64_t, cast=True] input, int win,
1630-
int minp, object func):
1630+
int minp, object func, object args, object kwargs):
16311631
cdef ndarray[double_t] output, counts, bufarr
16321632
cdef Py_ssize_t i, n
16331633
cdef float64_t *buf
@@ -1652,15 +1652,16 @@ def roll_generic(ndarray[float64_t, cast=True] input, int win,
16521652
n = len(input)
16531653
for i from 0 <= i < int_min(win, n):
16541654
if counts[i] >= minp:
1655-
output[i] = func(input[int_max(i - win + 1, 0) : i + 1])
1655+
output[i] = func(input[int_max(i - win + 1, 0) : i + 1], *args,
1656+
**kwargs)
16561657
else:
16571658
output[i] = NaN
16581659

16591660
for i from win <= i < n:
16601661
buf = buf + 1
16611662
bufarr.data = <char*> buf
16621663
if counts[i] >= minp:
1663-
output[i] = func(bufarr)
1664+
output[i] = func(bufarr, *args, **kwargs)
16641665
else:
16651666
output[i] = NaN
16661667

pandas/stats/moments.py

+35-20
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ def rolling_count(arg, window, freq=None, center=False, time_rule=None):
141141
center : boolean, default False
142142
Whether the label should correspond with center of window
143143
time_rule : Legacy alias for freq
144-
144+
145145
Returns
146146
-------
147147
rolling_count : type of caller
@@ -255,8 +255,8 @@ def rolling_corr_pairwise(df, window, min_periods=None):
255255
return Panel.from_dict(all_results).swapaxes('items', 'major')
256256

257257

258-
def _rolling_moment(arg, window, func, minp, axis=0, freq=None,
259-
center=False, time_rule=None, **kwargs):
258+
def _rolling_moment(arg, window, func, minp, axis=0, freq=None, center=False,
259+
time_rule=None, args=(), kwargs={}, **kwds):
260260
"""
261261
Rolling statistical measure using supplied function. Designed to be
262262
used with passed-in Cython array-based functions.
@@ -274,13 +274,18 @@ def _rolling_moment(arg, window, func, minp, axis=0, freq=None,
274274
center : boolean, default False
275275
Whether the label should correspond with center of window
276276
time_rule : Legacy alias for freq
277-
277+
args : tuple
278+
Passed on to func
279+
kwargs : dict
280+
Passed on to func
281+
278282
Returns
279283
-------
280284
y : type of input
281285
"""
282286
arg = _conv_timerule(arg, freq, time_rule)
283-
calc = lambda x: func(x, window, minp=minp, **kwargs)
287+
calc = lambda x: func(x, window, minp=minp, args=args, kwargs=kwargs,
288+
**kwds)
284289
return_hook, values = _process_data_structure(arg)
285290
# actually calculate the moment. Faster way to do this?
286291
if values.ndim > 1:
@@ -509,7 +514,7 @@ def _rolling_func(func, desc, check_minp=_use_window):
509514
@wraps(func)
510515
def f(arg, window, min_periods=None, freq=None, center=False,
511516
time_rule=None, **kwargs):
512-
def call_cython(arg, window, minp, **kwds):
517+
def call_cython(arg, window, minp, args=(), kwargs={}, **kwds):
513518
minp = check_minp(minp, window)
514519
return func(arg, window, minp, **kwds)
515520
return _rolling_moment(arg, window, call_cython, min_periods,
@@ -551,21 +556,21 @@ def rolling_quantile(arg, window, quantile, min_periods=None, freq=None,
551556
center : boolean, default False
552557
Whether the label should correspond with center of window
553558
time_rule : Legacy alias for freq
554-
559+
555560
Returns
556561
-------
557562
y : type of input argument
558563
"""
559564

560-
def call_cython(arg, window, minp):
565+
def call_cython(arg, window, minp, args=(), kwargs={}):
561566
minp = _use_window(minp, window)
562567
return algos.roll_quantile(arg, window, minp, quantile)
563568
return _rolling_moment(arg, window, call_cython, min_periods,
564569
freq=freq, center=center, time_rule=time_rule)
565570

566571

567572
def rolling_apply(arg, window, func, min_periods=None, freq=None,
568-
center=False, time_rule=None):
573+
center=False, time_rule=None, args=(), kwargs={}):
569574
"""Generic moving function application
570575
571576
Parameters
@@ -581,16 +586,21 @@ def rolling_apply(arg, window, func, min_periods=None, freq=None,
581586
center : boolean, default False
582587
Whether the label should correspond with center of window
583588
time_rule : Legacy alias for freq
584-
589+
args : tuple
590+
Passed on to func
591+
kwargs : dict
592+
Passed on to func
593+
585594
Returns
586595
-------
587596
y : type of input argument
588597
"""
589-
def call_cython(arg, window, minp):
598+
def call_cython(arg, window, minp, args, kwargs):
590599
minp = _use_window(minp, window)
591-
return algos.roll_generic(arg, window, minp, func)
600+
return algos.roll_generic(arg, window, minp, func, args, kwargs)
592601
return _rolling_moment(arg, window, call_cython, min_periods,
593-
freq=freq, center=center, time_rule=time_rule)
602+
freq=freq, center=center, time_rule=time_rule,
603+
args=args, kwargs=kwargs)
594604

595605

596606
def rolling_window(arg, window=None, win_type=None, min_periods=None,
@@ -618,7 +628,7 @@ def rolling_window(arg, window=None, win_type=None, min_periods=None,
618628
If True computes weighted mean, else weighted sum
619629
time_rule : Legacy alias for freq
620630
axis : {0, 1}, default 0
621-
631+
622632
Returns
623633
-------
624634
y : type of input argument
@@ -703,7 +713,7 @@ def f(arg, min_periods=1, freq=None, center=False, time_rule=None,
703713
**kwargs):
704714
window = len(arg)
705715

706-
def call_cython(arg, window, minp, **kwds):
716+
def call_cython(arg, window, minp, args=(), kwargs={}, **kwds):
707717
minp = check_minp(minp, window)
708718
return func(arg, window, minp, **kwds)
709719
return _rolling_moment(arg, window, call_cython, min_periods,
@@ -744,7 +754,7 @@ def expanding_count(arg, freq=None, center=False, time_rule=None):
744754
center : boolean, default False
745755
Whether the label should correspond with center of window
746756
time_rule : Legacy alias for freq
747-
757+
748758
Returns
749759
-------
750760
expanding_count : type of caller
@@ -768,7 +778,7 @@ def expanding_quantile(arg, quantile, min_periods=1, freq=None,
768778
center : boolean, default False
769779
Whether the label should correspond with center of window
770780
time_rule : Legacy alias for freq
771-
781+
772782
Returns
773783
-------
774784
y : type of input argument
@@ -818,7 +828,7 @@ def expanding_corr_pairwise(df, min_periods=1):
818828

819829

820830
def expanding_apply(arg, func, min_periods=1, freq=None, center=False,
821-
time_rule=None):
831+
time_rule=None, args=(), kwargs={}):
822832
"""Generic expanding function application
823833
824834
Parameters
@@ -833,11 +843,16 @@ def expanding_apply(arg, func, min_periods=1, freq=None, center=False,
833843
center : boolean, default False
834844
Whether the label should correspond with center of window
835845
time_rule : Legacy alias for freq
836-
846+
args : tuple
847+
Passed on to func
848+
kwargs : dict
849+
Passed on to func
850+
837851
Returns
838852
-------
839853
y : type of input argument
840854
"""
841855
window = len(arg)
842856
return rolling_apply(arg, window, func, min_periods=min_periods, freq=freq,
843-
center=center, time_rule=time_rule)
857+
center=center, time_rule=time_rule, args=args,
858+
kwargs=kwargs)

pandas/stats/tests/test_moments.py

+15
Original file line numberDiff line numberDiff line change
@@ -694,6 +694,21 @@ def expanding_mean(x, min_periods=1, freq=None):
694694
freq=freq)
695695
self._check_expanding(expanding_mean, np.mean)
696696

697+
def test_expanding_apply_args_kwargs(self):
698+
def mean_w_arg(x, const):
699+
return np.mean(x) + const
700+
701+
df = DataFrame(np.random.rand(20, 3))
702+
703+
expected = mom.expanding_apply(df, np.mean) + 20.
704+
705+
assert_frame_equal(mom.expanding_apply(df, mean_w_arg, args=(20,)),
706+
expected)
707+
assert_frame_equal(mom.expanding_apply(df, mean_w_arg,
708+
kwargs={'const' : 20}),
709+
expected)
710+
711+
697712
def test_expanding_corr(self):
698713
A = self.series.dropna()
699714
B = (A + randn(len(A)))[:-5]

0 commit comments

Comments
 (0)