Skip to content

Add Numba to rolling.apply #29

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
3 changes: 2 additions & 1 deletion asv_bench/benchmarks/rolling.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ class Apply:
["DataFrame", "Series"],
[10, 1000],
["int", "float"],
[sum, np.sum, lambda x: np.sum(x) + 5],
# TODO: numba doesn't support builtin.sum
[np.sum, lambda x: np.sum(x) + 5],
[True, False],
)
param_names = ["contructor", "window", "dtype", "function", "raw"]
Expand Down
1 change: 0 additions & 1 deletion pandas/core/window/aggregators/methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
This implementation mimics what we currently do in cython except the
calculation of window bounds is independent of the aggregation routine.
"""

import numba
import numpy as np

Expand Down
99 changes: 70 additions & 29 deletions pandas/core/window/rolling.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@
from datetime import timedelta
from functools import partial
from textwrap import dedent
from typing import Callable, List, Optional, Set, Union
from typing import Callable, Dict, List, Optional, Set, Union
import warnings

import numba
import numpy as np

import pandas._libs.window as libwindow
Expand Down Expand Up @@ -94,6 +95,7 @@ def __init__(
self.win_freq = None
self.axis = obj._get_axis_number(axis) if axis is not None else None
self.validate()
self._apply_func_cache = dict() # type: Dict

@property
def _constructor(self):
Expand Down Expand Up @@ -431,7 +433,13 @@ def _apply(
-------
y : type of input
"""
use_numba = kwargs.pop("use_numba", None)
use_numba = kwargs.pop("use_numba", False)
floor = kwargs.pop("floor", None)
if not use_numba:
# apply stores use_numba and floor in kwargs[kwargs]
extra_kwargs = kwargs.pop("kwargs", {})
use_numba = extra_kwargs.get("use_numba", False)
floor = extra_kwargs.get("floor", None)

if center is None:
center = self.center
Expand Down Expand Up @@ -487,12 +495,16 @@ def _apply(
window,
_use_window(self.min_periods, window),
len(values) + offset,
floor,
)
else:
minimum_periods = _check_min_periods(
self.min_periods or 1, self.min_periods, len(values) + offset
self.min_periods or 1,
self.min_periods,
len(values) + offset,
floor,
)
func = partial( # type: ignore
func_partial = partial( # type: ignore
func, begin=start, end=end, minimum_periods=minimum_periods
)

Expand All @@ -510,7 +522,7 @@ def _apply(
cfunc, check_minp, index_as_array, **kwargs
)

func = partial( # type: ignore
func_partial = partial( # type: ignore
func,
window=window,
min_periods=self.min_periods,
Expand All @@ -520,12 +532,12 @@ def _apply(
if additional_nans is not None:

def calc(x):
return func(np.concatenate((x, additional_nans)))
return func_partial(np.concatenate((x, additional_nans)))

else:

def calc(x):
return func(x)
return func_partial(x)

with np.errstate(all="ignore"):
if values.ndim > 1:
Expand All @@ -534,6 +546,9 @@ def calc(x):
result = calc(values)
result = np.asarray(result)

if use_numba:
self._apply_func_cache[name] = func

if center:
result = self._center_window(result, window)

Expand Down Expand Up @@ -1106,12 +1121,8 @@ def count(self):
)

def apply(self, func, raw=None, args=(), kwargs={}):
from pandas import Series

kwargs.pop("_level", None)
window = self._get_window()
offset = _offset(window, self.center)
index_as_array = self._get_index()

# TODO: default is for backward compat
# change to False in the future
Expand All @@ -1127,24 +1138,54 @@ def apply(self, func, raw=None, args=(), kwargs={}):
)
raw = True

def f(arg, window, min_periods, closed):
minp = _use_window(min_periods, window)
if not raw:
arg = Series(arg, index=self.obj.index)
return libwindow.roll_generic(
arg,
window,
minp,
index_as_array,
closed,
offset,
func,
raw,
args,
kwargs,
)

return self._apply(f, func, args=args, kwargs=kwargs, center=False, raw=raw)
# Numba doesn't support kwargs in nopython mode
# https://github.com/numba/numba/issues/2916
if func not in self._apply_func_cache:

def make_rolling_apply(func):
@numba.generated_jit(nopython=True)
def numba_func(window, *_args):
if getattr(np, func.__name__, False) is func:

def impl(window, *_args):
return func(window, *_args)

return impl
else:
jf = numba.njit(func)

def impl(window, *_args):
return jf(window, *_args)

return impl

@numba.njit
def roll_apply(
values: np.ndarray,
begin: np.ndarray,
end: np.ndarray,
minimum_periods: int,
):
result = np.empty(len(begin))
for i, (start, stop) in enumerate(zip(begin, end)):
window = values[start:stop]
count_nan = np.sum(np.isnan(window))
if len(window) - count_nan >= minimum_periods:
result[i] = numba_func(window, *args)
else:
result[i] = np.nan
return result

return roll_apply

rolling_apply = make_rolling_apply(func)
else:
rolling_apply = self._apply_func_cache[func]
kwargs["use_numba"] = True
kwargs["floor"] = 0
return self._apply(
rolling_apply, func, args=args, kwargs=kwargs, center=False, raw=raw
)

def sum(self, *args, **kwargs):
nv.validate_window_func("sum", args, kwargs)
Expand Down
1 change: 1 addition & 0 deletions pandas/tests/window/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ def test_agg(self):
expected.columns = pd.MultiIndex.from_tuples(exp_cols)
tm.assert_frame_equal(result, expected, check_like=True)

@pytest.mark.xfail(reason="TypingError: numba doesn't support kwarg for std")
def test_agg_apply(self, raw):

# passed lambda
Expand Down
79 changes: 54 additions & 25 deletions pandas/tests/window/test_moments.py
Original file line number Diff line number Diff line change
Expand Up @@ -628,6 +628,10 @@ def test_rolling_quantile_param(self):
with pytest.raises(TypeError):
ser.rolling(3).quantile("foo")

@pytest.mark.xfail(
reason="unsupported controlflow due to return/raise statements "
"inside with block"
)
def test_rolling_apply(self, raw):
# suppress warnings about empty slices, as we are deliberately testing
# with a 0-length Series
Expand Down Expand Up @@ -679,6 +683,10 @@ def test_rolling_apply_out_of_bounds(self, raw):
expected = pd.Series([1, 3, 6, 10], dtype=float)
tm.assert_almost_equal(result, expected)

@pytest.mark.xfail(
reason="Untyped global name 'df': "
"cannot determine Numba type of <class 'pandas.core.frame.DataFrame'>"
)
@pytest.mark.parametrize("window", [2, "2s"])
def test_rolling_apply_with_pandas_objects(self, window):
# 5071
Expand Down Expand Up @@ -1629,6 +1637,10 @@ def _ewma(s, com, min_periods, adjust, ignore_na):
),
)

@pytest.mark.xfail(
reason="Untyped global name 'Series': cannot determine "
"Numba type of <class 'type'>"
)
@pytest.mark.slow
@pytest.mark.parametrize("min_periods", [0, 1, 2, 3, 4])
def test_expanding_consistency(self, min_periods):
Expand Down Expand Up @@ -1701,6 +1713,10 @@ def test_expanding_consistency(self, min_periods):
if name in ["sum", "prod"]:
tm.assert_equal(expanding_f_result, expanding_apply_f_result)

@pytest.mark.xfail(
reason="Untyped global name 'Series': cannot determine Numba type of "
"<class 'type'>"
)
@pytest.mark.slow
@pytest.mark.parametrize(
"window,min_periods,center", list(_rolling_consistency_cases())
Expand Down Expand Up @@ -1977,6 +1993,7 @@ def func(A, B, com, **kwargs):
with pytest.raises(Exception, match=msg):
func(A, randn(50), 20, min_periods=5)

@pytest.mark.xfail(reason="Use of unsupported opcode (SETUP_EXCEPT) found")
def test_expanding_apply_args_kwargs(self, raw):
def mean_w_arg(x, const):
return np.mean(x) + const
Expand Down Expand Up @@ -2118,8 +2135,18 @@ def test_rolling_corr_diff_length(self):
lambda x: x.rolling(window=10, min_periods=5).kurt(),
lambda x: x.rolling(window=10, min_periods=5).quantile(quantile=0.5),
lambda x: x.rolling(window=10, min_periods=5).median(),
lambda x: x.rolling(window=10, min_periods=5).apply(sum, raw=False),
lambda x: x.rolling(window=10, min_periods=5).apply(sum, raw=True),
pytest.param(
lambda x: x.rolling(window=10, min_periods=5).apply(sum, raw=False),
marks=pytest.mark.xfail(
reason="https://github.com/numba/numba/issues/4587"
),
),
pytest.param(
lambda x: x.rolling(window=10, min_periods=5).apply(sum, raw=True),
marks=pytest.mark.xfail(
reason="https://github.com/numba/numba/issues/4587"
),
),
lambda x: x.rolling(win_type="boxcar", window=10, min_periods=5).mean(),
],
)
Expand Down Expand Up @@ -2164,17 +2191,9 @@ def test_rolling_functions_window_non_shrinkage_binary(self):
df_result = f(df)
tm.assert_frame_equal(df_result, df_expected)

def test_moment_functions_zero_length(self):
# GH 8056
s = Series()
s_expected = s
df1 = DataFrame()
df1_expected = df1
df2 = DataFrame(columns=["a"])
df2["a"] = df2["a"].astype("float64")
df2_expected = df2

functions = [
@pytest.mark.parametrize(
"f",
[
lambda x: x.expanding().count(),
lambda x: x.expanding(min_periods=5).cov(x, pairwise=False),
lambda x: x.expanding(min_periods=5).corr(x, pairwise=False),
Expand Down Expand Up @@ -2206,21 +2225,31 @@ def test_moment_functions_zero_length(self):
lambda x: x.rolling(window=10, min_periods=5).apply(sum, raw=False),
lambda x: x.rolling(window=10, min_periods=5).apply(sum, raw=True),
lambda x: x.rolling(win_type="boxcar", window=10, min_periods=5).mean(),
]
for f in functions:
try:
s_result = f(s)
tm.assert_series_equal(s_result, s_expected)
],
)
def test_moment_functions_zero_length(self, f):
# GH 8056
s = Series()
s_expected = s
df1 = DataFrame()
df1_expected = df1
df2 = DataFrame(columns=["a"])
df2["a"] = df2["a"].astype("float64")
df2_expected = df2

df1_result = f(df1)
tm.assert_frame_equal(df1_result, df1_expected)
try:
s_result = f(s)
tm.assert_series_equal(s_result, s_expected)

df2_result = f(df2)
tm.assert_frame_equal(df2_result, df2_expected)
except (ImportError):
df1_result = f(df1)
tm.assert_frame_equal(df1_result, df1_expected)

# scipy needed for rolling_window
continue
df2_result = f(df2)
tm.assert_frame_equal(df2_result, df2_expected)
except (ImportError):

# scipy needed for rolling_window
pass

def test_moment_functions_zero_length_pairwise(self):

Expand Down
1 change: 1 addition & 0 deletions pandas/tests/window/test_rolling.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ def test_constructor_with_timedelta_window(self, window):
expected = df.rolling("3D").sum()
tm.assert_frame_equal(result, expected)

@pytest.mark.xfail(reason="https://github.com/numba/numba/issues/4587")
@pytest.mark.parametrize("window", [timedelta(days=3), pd.Timedelta(days=3), "3D"])
def test_constructor_timedelta_window_and_minperiods(self, window, raw):
# GH 15305
Expand Down