Skip to content

ENH: Allow args and kwargs to generic expanding/rolling apply. #6289

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/source/release.rst
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ Improvements to existing features

- pd.read_clipboard will, if 'sep' is unspecified, try to detect data copied from a spreadsheet
and parse accordingly. (:issue:`6223`)
- pd.expanding_apply and pd.rolling_apply now take args and kwargs that are passed on to
the func. (:issue:`6289`)

.. _release.bug_fixes-0.14.0:

Expand Down
7 changes: 4 additions & 3 deletions pandas/algos.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -1627,7 +1627,7 @@ def roll_quantile(ndarray[float64_t, cast=True] input, int win,
return output

def roll_generic(ndarray[float64_t, cast=True] input, int win,
int minp, object func):
int minp, object func, object args, object kwargs):
cdef ndarray[double_t] output, counts, bufarr
cdef Py_ssize_t i, n
cdef float64_t *buf
Expand All @@ -1652,15 +1652,16 @@ def roll_generic(ndarray[float64_t, cast=True] input, int win,
n = len(input)
for i from 0 <= i < int_min(win, n):
if counts[i] >= minp:
output[i] = func(input[int_max(i - win + 1, 0) : i + 1])
output[i] = func(input[int_max(i - win + 1, 0) : i + 1], *args,
**kwargs)
else:
output[i] = NaN

for i from win <= i < n:
buf = buf + 1
bufarr.data = <char*> buf
if counts[i] >= minp:
output[i] = func(bufarr)
output[i] = func(bufarr, *args, **kwargs)
else:
output[i] = NaN

Expand Down
55 changes: 35 additions & 20 deletions pandas/stats/moments.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ def rolling_count(arg, window, freq=None, center=False, time_rule=None):
center : boolean, default False
Whether the label should correspond with center of window
time_rule : Legacy alias for freq

Returns
-------
rolling_count : type of caller
Expand Down Expand Up @@ -255,8 +255,8 @@ def rolling_corr_pairwise(df, window, min_periods=None):
return Panel.from_dict(all_results).swapaxes('items', 'major')


def _rolling_moment(arg, window, func, minp, axis=0, freq=None,
center=False, time_rule=None, **kwargs):
def _rolling_moment(arg, window, func, minp, axis=0, freq=None, center=False,
time_rule=None, args=(), kwargs={}, **kwds):
"""
Rolling statistical measure using supplied function. Designed to be
used with passed-in Cython array-based functions.
Expand All @@ -274,13 +274,18 @@ def _rolling_moment(arg, window, func, minp, axis=0, freq=None,
center : boolean, default False
Whether the label should correspond with center of window
time_rule : Legacy alias for freq

args : tuple
Passed on to func
kwargs : dict
Passed on to func

Returns
-------
y : type of input
"""
arg = _conv_timerule(arg, freq, time_rule)
calc = lambda x: func(x, window, minp=minp, **kwargs)
calc = lambda x: func(x, window, minp=minp, args=args, kwargs=kwargs,
**kwds)
return_hook, values = _process_data_structure(arg)
# actually calculate the moment. Faster way to do this?
if values.ndim > 1:
Expand Down Expand Up @@ -509,7 +514,7 @@ def _rolling_func(func, desc, check_minp=_use_window):
@wraps(func)
def f(arg, window, min_periods=None, freq=None, center=False,
time_rule=None, **kwargs):
def call_cython(arg, window, minp, **kwds):
def call_cython(arg, window, minp, args=(), kwargs={}, **kwds):
minp = check_minp(minp, window)
return func(arg, window, minp, **kwds)
return _rolling_moment(arg, window, call_cython, min_periods,
Expand Down Expand Up @@ -551,21 +556,21 @@ def rolling_quantile(arg, window, quantile, min_periods=None, freq=None,
center : boolean, default False
Whether the label should correspond with center of window
time_rule : Legacy alias for freq

Returns
-------
y : type of input argument
"""

def call_cython(arg, window, minp):
def call_cython(arg, window, minp, args=(), kwargs={}):
minp = _use_window(minp, window)
return algos.roll_quantile(arg, window, minp, quantile)
return _rolling_moment(arg, window, call_cython, min_periods,
freq=freq, center=center, time_rule=time_rule)


def rolling_apply(arg, window, func, min_periods=None, freq=None,
center=False, time_rule=None):
center=False, time_rule=None, args=(), kwargs={}):
"""Generic moving function application

Parameters
Expand All @@ -581,16 +586,21 @@ def rolling_apply(arg, window, func, min_periods=None, freq=None,
center : boolean, default False
Whether the label should correspond with center of window
time_rule : Legacy alias for freq

args : tuple
Passed on to func
kwargs : dict
Passed on to func

Returns
-------
y : type of input argument
"""
def call_cython(arg, window, minp):
def call_cython(arg, window, minp, args, kwargs):
minp = _use_window(minp, window)
return algos.roll_generic(arg, window, minp, func)
return algos.roll_generic(arg, window, minp, func, args, kwargs)
return _rolling_moment(arg, window, call_cython, min_periods,
freq=freq, center=center, time_rule=time_rule)
freq=freq, center=center, time_rule=time_rule,
args=args, kwargs=kwargs)


def rolling_window(arg, window=None, win_type=None, min_periods=None,
Expand Down Expand Up @@ -618,7 +628,7 @@ def rolling_window(arg, window=None, win_type=None, min_periods=None,
If True computes weighted mean, else weighted sum
time_rule : Legacy alias for freq
axis : {0, 1}, default 0

Returns
-------
y : type of input argument
Expand Down Expand Up @@ -703,7 +713,7 @@ def f(arg, min_periods=1, freq=None, center=False, time_rule=None,
**kwargs):
window = len(arg)

def call_cython(arg, window, minp, **kwds):
def call_cython(arg, window, minp, args=(), kwargs={}, **kwds):
minp = check_minp(minp, window)
return func(arg, window, minp, **kwds)
return _rolling_moment(arg, window, call_cython, min_periods,
Expand Down Expand Up @@ -744,7 +754,7 @@ def expanding_count(arg, freq=None, center=False, time_rule=None):
center : boolean, default False
Whether the label should correspond with center of window
time_rule : Legacy alias for freq

Returns
-------
expanding_count : type of caller
Expand All @@ -768,7 +778,7 @@ def expanding_quantile(arg, quantile, min_periods=1, freq=None,
center : boolean, default False
Whether the label should correspond with center of window
time_rule : Legacy alias for freq

Returns
-------
y : type of input argument
Expand Down Expand Up @@ -818,7 +828,7 @@ def expanding_corr_pairwise(df, min_periods=1):


def expanding_apply(arg, func, min_periods=1, freq=None, center=False,
time_rule=None):
time_rule=None, args=(), kwargs={}):
"""Generic expanding function application

Parameters
Expand All @@ -833,11 +843,16 @@ def expanding_apply(arg, func, min_periods=1, freq=None, center=False,
center : boolean, default False
Whether the label should correspond with center of window
time_rule : Legacy alias for freq

args : tuple
Passed on to func
kwargs : dict
Passed on to func

Returns
-------
y : type of input argument
"""
window = len(arg)
return rolling_apply(arg, window, func, min_periods=min_periods, freq=freq,
center=center, time_rule=time_rule)
center=center, time_rule=time_rule, args=args,
kwargs=kwargs)
15 changes: 15 additions & 0 deletions pandas/stats/tests/test_moments.py
Original file line number Diff line number Diff line change
Expand Up @@ -694,6 +694,21 @@ def expanding_mean(x, min_periods=1, freq=None):
freq=freq)
self._check_expanding(expanding_mean, np.mean)

def test_expanding_apply_args_kwargs(self):
    """Verify expanding_apply forwards ``args`` and ``kwargs`` to ``func``.

    The applied function adds a constant to the window mean, so the result
    must equal the plain expanding mean shifted by that constant, whether
    the constant arrives positionally or by keyword.
    """
    def shifted_mean(x, const):
        # Window mean displaced by the caller-supplied constant.
        return np.mean(x) + const

    frame = DataFrame(np.random.rand(20, 3))

    # Reference: apply the bare mean, then add the constant afterwards.
    expected = mom.expanding_apply(frame, np.mean) + 20.

    # Constant delivered through the positional ``args`` tuple.
    via_args = mom.expanding_apply(frame, shifted_mean, args=(20,))
    assert_frame_equal(via_args, expected)

    # Constant delivered through the ``kwargs`` dict.
    via_kwargs = mom.expanding_apply(frame, shifted_mean,
                                     kwargs={'const': 20})
    assert_frame_equal(via_kwargs, expected)


def test_expanding_corr(self):
A = self.series.dropna()
B = (A + randn(len(A)))[:-5]
Expand Down