Skip to content

Commit 4ab071f

Browse files
mroeschkePuneethaPai
authored andcommitted
CLN: Move rolling helper functions to where they are used (#34269)
1 parent 0da4703 commit 4ab071f

File tree

3 files changed

+106
-90
lines changed

3 files changed

+106
-90
lines changed

pandas/core/window/common.py

-80
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,8 @@
55

66
import numpy as np
77

8-
from pandas.core.dtypes.common import is_integer
98
from pandas.core.dtypes.generic import ABCDataFrame, ABCSeries
109

11-
import pandas.core.common as com
1210
from pandas.core.generic import _shared_docs
1311
from pandas.core.groupby.base import GroupByMixin
1412
from pandas.core.indexes.api import MultiIndex
@@ -224,75 +222,6 @@ def dataframe_from_int_dict(data, frame_template):
224222
return _flex_binary_moment(arg2, arg1, f)
225223

226224

227-
def _get_center_of_mass(comass, span, halflife, alpha):
228-
valid_count = com.count_not_none(comass, span, halflife, alpha)
229-
if valid_count > 1:
230-
raise ValueError("comass, span, halflife, and alpha are mutually exclusive")
231-
232-
# Convert to center of mass; domain checks ensure 0 < alpha <= 1
233-
if comass is not None:
234-
if comass < 0:
235-
raise ValueError("comass must satisfy: comass >= 0")
236-
elif span is not None:
237-
if span < 1:
238-
raise ValueError("span must satisfy: span >= 1")
239-
comass = (span - 1) / 2.0
240-
elif halflife is not None:
241-
if halflife <= 0:
242-
raise ValueError("halflife must satisfy: halflife > 0")
243-
decay = 1 - np.exp(np.log(0.5) / halflife)
244-
comass = 1 / decay - 1
245-
elif alpha is not None:
246-
if alpha <= 0 or alpha > 1:
247-
raise ValueError("alpha must satisfy: 0 < alpha <= 1")
248-
comass = (1.0 - alpha) / alpha
249-
else:
250-
raise ValueError("Must pass one of comass, span, halflife, or alpha")
251-
252-
return float(comass)
253-
254-
255-
def calculate_center_offset(window):
256-
if not is_integer(window):
257-
window = len(window)
258-
return int((window - 1) / 2.0)
259-
260-
261-
def calculate_min_periods(
262-
window: int,
263-
min_periods: Optional[int],
264-
num_values: int,
265-
required_min_periods: int,
266-
floor: int,
267-
) -> int:
268-
"""
269-
Calculates final minimum periods value for rolling aggregations.
270-
271-
Parameters
272-
----------
273-
window : passed window value
274-
min_periods : passed min periods value
275-
num_values : total number of values
276-
required_min_periods : required min periods per aggregation function
277-
floor : required min periods per aggregation function
278-
279-
Returns
280-
-------
281-
min_periods : int
282-
"""
283-
if min_periods is None:
284-
min_periods = window
285-
else:
286-
min_periods = max(required_min_periods, min_periods)
287-
if min_periods > window:
288-
raise ValueError(f"min_periods {min_periods} must be <= window {window}")
289-
elif min_periods > num_values:
290-
min_periods = num_values + 1
291-
elif min_periods < 0:
292-
raise ValueError("min_periods must be >= 0")
293-
return max(min_periods, floor)
294-
295-
296225
def zsqrt(x):
297226
with np.errstate(all="ignore"):
298227
result = np.sqrt(x)
@@ -317,12 +246,3 @@ def prep_binary(arg1, arg2):
317246
Y = arg2 + 0 * arg1
318247

319248
return X, Y
320-
321-
322-
def get_weighted_roll_func(cfunc: Callable) -> Callable:
323-
def func(arg, window, min_periods=None):
324-
if min_periods is None:
325-
min_periods = len(window)
326-
return cfunc(arg, window, min_periods)
327-
328-
return func

pandas/core/window/ewm.py

+31-7
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,8 @@
99
from pandas.core.dtypes.generic import ABCDataFrame
1010

1111
from pandas.core.base import DataError
12-
from pandas.core.window.common import (
13-
_doc_template,
14-
_get_center_of_mass,
15-
_shared_docs,
16-
zsqrt,
17-
)
12+
import pandas.core.common as com
13+
from pandas.core.window.common import _doc_template, _shared_docs, zsqrt
1814
from pandas.core.window.rolling import _flex_binary_moment, _Rolling
1915

2016
_bias_template = """
@@ -27,6 +23,34 @@
2723
"""
2824

2925

26+
def get_center_of_mass(comass, span, halflife, alpha) -> float:
27+
valid_count = com.count_not_none(comass, span, halflife, alpha)
28+
if valid_count > 1:
29+
raise ValueError("comass, span, halflife, and alpha are mutually exclusive")
30+
31+
# Convert to center of mass; domain checks ensure 0 < alpha <= 1
32+
if comass is not None:
33+
if comass < 0:
34+
raise ValueError("comass must satisfy: comass >= 0")
35+
elif span is not None:
36+
if span < 1:
37+
raise ValueError("span must satisfy: span >= 1")
38+
comass = (span - 1) / 2.0
39+
elif halflife is not None:
40+
if halflife <= 0:
41+
raise ValueError("halflife must satisfy: halflife > 0")
42+
decay = 1 - np.exp(np.log(0.5) / halflife)
43+
comass = 1 / decay - 1
44+
elif alpha is not None:
45+
if alpha <= 0 or alpha > 1:
46+
raise ValueError("alpha must satisfy: 0 < alpha <= 1")
47+
comass = (1.0 - alpha) / alpha
48+
else:
49+
raise ValueError("Must pass one of comass, span, halflife, or alpha")
50+
51+
return float(comass)
52+
53+
3054
class EWM(_Rolling):
3155
r"""
3256
Provide exponential weighted (EW) functions.
@@ -144,7 +168,7 @@ def __init__(
144168
axis=0,
145169
):
146170
self.obj = obj
147-
self.com = _get_center_of_mass(com, span, halflife, alpha)
171+
self.com = get_center_of_mass(com, span, halflife, alpha)
148172
self.min_periods = min_periods
149173
self.adjust = adjust
150174
self.ignore_na = ignore_na

pandas/core/window/rolling.py

+75-3
Original file line numberDiff line numberDiff line change
@@ -44,9 +44,6 @@
4444
_doc_template,
4545
_flex_binary_moment,
4646
_shared_docs,
47-
calculate_center_offset,
48-
calculate_min_periods,
49-
get_weighted_roll_func,
5047
zsqrt,
5148
)
5249
from pandas.core.window.indexers import (
@@ -59,6 +56,81 @@
5956
from pandas.tseries.offsets import DateOffset
6057

6158

59+
def calculate_center_offset(window) -> int:
60+
"""
61+
Calculate an offset necessary to have the window label to be centered.
62+
63+
Parameters
64+
----------
65+
window: ndarray or int
66+
window weights or window
67+
68+
Returns
69+
-------
70+
int
71+
"""
72+
if not is_integer(window):
73+
window = len(window)
74+
return int((window - 1) / 2.0)
75+
76+
77+
def calculate_min_periods(
78+
window: int,
79+
min_periods: Optional[int],
80+
num_values: int,
81+
required_min_periods: int,
82+
floor: int,
83+
) -> int:
84+
"""
85+
Calculate final minimum periods value for rolling aggregations.
86+
87+
Parameters
88+
----------
89+
window : passed window value
90+
min_periods : passed min periods value
91+
num_values : total number of values
92+
required_min_periods : required min periods per aggregation function
93+
floor : required min periods per aggregation function
94+
95+
Returns
96+
-------
97+
min_periods : int
98+
"""
99+
if min_periods is None:
100+
min_periods = window
101+
else:
102+
min_periods = max(required_min_periods, min_periods)
103+
if min_periods > window:
104+
raise ValueError(f"min_periods {min_periods} must be <= window {window}")
105+
elif min_periods > num_values:
106+
min_periods = num_values + 1
107+
elif min_periods < 0:
108+
raise ValueError("min_periods must be >= 0")
109+
return max(min_periods, floor)
110+
111+
112+
def get_weighted_roll_func(cfunc: Callable) -> Callable:
113+
"""
114+
Wrap weighted rolling cython function with min periods argument.
115+
116+
Parameters
117+
----------
118+
cfunc : function
119+
Cython weighted rolling function
120+
121+
Returns
122+
-------
123+
function
124+
"""
125+
126+
def func(arg, window, min_periods=None):
127+
if min_periods is None:
128+
min_periods = len(window)
129+
return cfunc(arg, window, min_periods)
130+
131+
return func
132+
133+
62134
class _Window(PandasObject, ShallowMixin, SelectionMixin):
63135
_attributes: List[str] = [
64136
"window",

0 commit comments

Comments
 (0)