Skip to content

Commit b0ed4cf

Browse files
refactor new trendlines
1 parent 28f8e1f commit b0ed4cf

File tree

3 files changed

+56
-29
lines changed

3 files changed

+56
-29
lines changed

Diff for: packages/python/plotly/plotly/express/_core.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import plotly.io as pio
33
from collections import namedtuple, OrderedDict
44
from ._special_inputs import IdentityMap, Constant, Range
5-
from .trendline_functions import ols, lowess, ma, ewma
5+
from .trendline_functions import ols, lowess, rolling, expanding, ewm
66

77
from _plotly_utils.basevalidators import ColorscaleValidator
88
from plotly.colors import qualitative, sequential
@@ -17,7 +17,9 @@
1717
)
1818

1919
NO_COLOR = "px_no_color_constant"
20-
trendline_functions = dict(lowess=lowess, ma=ma, ewma=ewma, ols=ols)
20+
trendline_functions = dict(
21+
lowess=lowess, rolling=rolling, ewm=ewm, expanding=expanding, ols=ols
22+
)
2123

2224
# Declare all supported attributes, across all plot types
2325
direct_attrables = (

Diff for: packages/python/plotly/plotly/express/trendline_functions/__init__.py

+39-18
Original file line numberDiff line numberDiff line change
@@ -66,9 +66,9 @@ def ols(trendline_options, x_raw, x, y, x_label, y_label, non_missing):
6666
fit_results.params[0],
6767
)
6868
elif not add_constant:
69-
hover_header += "%s = %g * %s<br>" % (y_label, fit_results.params[0], x_label,)
69+
hover_header += "%s = %g * %s<br>" % (y_label, fit_results.params[0], x_label)
7070
else:
71-
hover_header += "%s = %g<br>" % (y_label, fit_results.params[0],)
71+
hover_header += "%s = %g<br>" % (y_label, fit_results.params[0])
7272
hover_header += "R<sup>2</sup>=%f<br><br>" % fit_results.rsquared
7373
return y_out, hover_header, fit_results
7474

@@ -91,27 +91,48 @@ def lowess(trendline_options, x_raw, x, y, x_label, y_label, non_missing):
9191
return y_out, hover_header, None
9292

9393

94-
def ma(trendline_options, x_raw, x, y, x_label, y_label, non_missing):
95-
"""Moving Average (MA) trendline function
94+
def _pandas(mode, trendline_options, x_raw, y, non_missing):
95+
modes = dict(rolling="Rolling", ewm="Exponentially Weighted", expanding="Expanding")
96+
function_name = trendline_options.pop("function", "mean")
97+
function_args = trendline_options.pop("function_args", dict())
98+
series = pd.Series(y, index=x_raw)
99+
agg = getattr(series, mode) # e.g. series.rolling
100+
agg_obj = agg(**trendline_options) # e.g. series.rolling(**opts)
101+
function = getattr(agg_obj, function_name) # e.g. series.rolling(**opts).mean
102+
y_out = function(**function_args) # e.g. series.rolling(**opts).mean(**opts)
103+
y_out = y_out[non_missing]
104+
hover_header = "<b>%s %s trendline</b><br><br>" % (modes[mode], function_name)
105+
return y_out, hover_header, None
106+
96107

97-
Requires `pandas` to be installed.
108+
def rolling(trendline_options, x_raw, x, y, x_label, y_label, non_missing):
109+
"""Rolling trendline function
98110
99-
The `trendline_options` dict is passed as keyword arguments into the
100-
`pandas.Series.rolling` function.
111+
The value of the `function` key of the `trendline_options` dict is the function to
112+
use (defaults to `mean`) and the value of the `function_args` key are taken to be
113+
its arguments as a dict. The remainder of the `trendline_options` dict is passed as
114+
keyword arguments into the `pandas.Series.rolling` function.
101115
"""
102-
y_out = pd.Series(y, index=x_raw).rolling(**trendline_options).mean()[non_missing]
103-
hover_header = "<b>MA trendline</b><br><br>"
104-
return y_out, hover_header, None
116+
return _pandas("rolling", trendline_options, x_raw, y, non_missing)
105117

106118

107-
def ewma(trendline_options, x_raw, x, y, x_label, y_label, non_missing):
108-
"""Exponentially Weighted Moving Average (EWMA) trendline function
119+
def expanding(trendline_options, x_raw, x, y, x_label, y_label, non_missing):
120+
"""Expanding trendline function
109121
110-
Requires `pandas` to be installed.
122+
The value of the `function` key of the `trendline_options` dict is the function to
123+
use (defaults to `mean`) and the value of the `function_args` key are taken to be
124+
its arguments as a dict. The remainder of the `trendline_options` dict is passed as
125+
keyword arguments into the `pandas.Series.expanding` function.
126+
"""
127+
return _pandas("expanding", trendline_options, x_raw, y, non_missing)
111128

112-
The `trendline_options` dict is passed as keyword arguments into the
113-
`pandas.Series.ewma` function.
129+
130+
def ewm(trendline_options, x_raw, x, y, x_label, y_label, non_missing):
131+
"""Exponentially weighted trendline function
132+
133+
The value of the `function` key of the `trendline_options` dict is the function to
134+
use (defaults to `mean`) and the value of the `function_args` key are taken to be
135+
its arguments as a dict. The remainder of the `trendline_options` dict is passed as
136+
keyword arguments into the `pandas.Series.ewm` function.
114137
"""
115-
y_out = pd.Series(y, index=x_raw).ewm(**trendline_options).mean()[non_missing]
116-
hover_header = "<b>EWMA trendline</b><br><br>"
117-
return y_out, hover_header, None
138+
return _pandas("ewm", trendline_options, x_raw, y, non_missing)

Diff for: packages/python/plotly/plotly/tests/test_optional/test_px/test_trendline.py

+13-9
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,9 @@
1111
("ols", None),
1212
("lowess", None),
1313
("lowess", dict(frac=0.3)),
14-
("ma", dict(window=2)),
15-
("ewma", dict(alpha=0.5)),
14+
("rolling", dict(window=2)),
15+
("expanding", None),
16+
("ewm", dict(alpha=0.5)),
1617
],
1718
)
1819
def test_trendline_results_passthrough(mode, options):
@@ -48,8 +49,9 @@ def test_trendline_results_passthrough(mode, options):
4849
("ols", None),
4950
("lowess", None),
5051
("lowess", dict(frac=0.3)),
51-
("ma", dict(window=2)),
52-
("ewma", dict(alpha=0.5)),
52+
("rolling", dict(window=2)),
53+
("expanding", None),
54+
("ewm", dict(alpha=0.5)),
5355
],
5456
)
5557
def test_trendline_enough_values(mode, options):
@@ -102,8 +104,9 @@ def test_trendline_enough_values(mode, options):
102104
("ols", dict(add_constant=False, log_x=True, log_y=True)),
103105
("lowess", None),
104106
("lowess", dict(frac=0.3)),
105-
("ma", dict(window=2)),
106-
("ewma", dict(alpha=0.5)),
107+
("rolling", dict(window=2)),
108+
("expanding", None),
109+
("ewm", dict(alpha=0.5)),
107110
],
108111
)
109112
def test_trendline_nan_values(mode, options):
@@ -173,9 +176,10 @@ def test_ols_trendline_slopes():
173176
("ols", None),
174177
("lowess", None),
175178
("lowess", dict(frac=0.3)),
176-
("ma", dict(window=2)),
177-
("ma", dict(window="10d")),
178-
("ewma", dict(alpha=0.5)),
179+
("rolling", dict(window=2)),
180+
("rolling", dict(window="10d")),
181+
("expanding", None),
182+
("ewm", dict(alpha=0.5)),
179183
],
180184
)
181185
def test_trendline_on_timeseries(mode, options):

0 commit comments

Comments
 (0)