|
12 | 12 | # See the License for the specific language governing permissions and
|
13 | 13 | # limitations under the License.
|
14 | 14 |
|
15 |
| -"""PyMC3 Plotting. |
| 15 | +"""Alias for the `plots` submodule from ArviZ. |
16 | 16 |
|
17 | 17 | Plots are delegated to the ArviZ library, a general purpose library for
|
18 |
| -"exploratory analysis of Bayesian models." See https://arviz-devs.github.io/arviz/ |
19 |
| -for details on plots. |
| 18 | +"exploratory analysis of Bayesian models." |
| 19 | +See https://arviz-devs.github.io/arviz/ for details on plots. |
20 | 20 | """
|
21 | 21 | import functools
|
22 | 22 | import sys
|
23 | 23 | import warnings
|
24 | 24 |
|
25 | 25 | import arviz as az
|
26 | 26 |
|
| 27 | +# Makes this module as identical to arviz.plots as possible |
| 28 | +for attr in dir(az.plots): |
| 29 | + obj = getattr(az.plots, attr) |
| 30 | + if not attr.startswith("__"): |
| 31 | + setattr(sys.modules[__name__], attr, obj) |
27 | 32 |
|
28 |
| -def map_args(func): |
29 |
| - swaps = [("varnames", "var_names")] |
30 | 33 |
|
| 34 | +def map_args(func, alias: str): |
31 | 35 | @functools.wraps(func)
|
32 | 36 | def wrapped(*args, **kwargs):
|
33 |
| - for (old, new) in swaps: |
34 |
| - if old in kwargs and new not in kwargs: |
35 |
| - warnings.warn( |
36 |
| - f"Keyword argument `{old}` renamed to `{new}`, and will be removed in pymc3 3.8" |
37 |
| - ) |
38 |
| - kwargs[new] = kwargs.pop(old) |
39 |
| - return func(*args, **kwargs) |
| 37 | + if "varnames" in kwargs: |
| 38 | + raise DeprecationWarning(f"The `varnames` kwarg was renamed to `var_names`.") |
| 39 | + original = func.__name__ |
| 40 | + warnings.warn( |
| 41 | + f"The function `{alias}` from PyMC3 is just an alias for `{original}` from ArviZ. " |
| 42 | + f"Please switch to `pymc3.{original}` or `arviz.{original}`.", |
| 43 | + DeprecationWarning, |
| 44 | + ) |
| 45 | + return func(*args, **kwargs) |
40 | 46 |
|
41 | 47 | return wrapped
|
42 | 48 |
|
43 | 49 |
|
44 |
| -# pymc3 custom plots: override these names for custom behavior |
45 |
| -autocorrplot = map_args(az.plot_autocorr) |
46 |
| -forestplot = map_args(az.plot_forest) |
47 |
| -kdeplot = map_args(az.plot_kde) |
48 |
| -plot_posterior = map_args(az.plot_posterior) |
49 |
| -energyplot = map_args(az.plot_energy) |
50 |
| -densityplot = map_args(az.plot_density) |
51 |
| -pairplot = map_args(az.plot_pair) |
52 |
| - |
53 |
| -# Use compact traceplot by default |
54 |
| -@map_args |
55 |
| -@functools.wraps(az.plot_trace) |
56 |
| -def traceplot(*args, **kwargs): |
57 |
| - try: |
58 |
| - kwargs.setdefault("compact", True) |
59 |
| - return az.plot_trace(*args, **kwargs) |
60 |
| - except TypeError: |
61 |
| - kwargs.pop("compact") |
62 |
| - return az.plot_trace(*args, **kwargs) |
| 50 | +# Aliases of ArviZ functions |
| 51 | +autocorrplot = map_args(az.plot_autocorr, alias="autocorrplot") |
| 52 | +forestplot = map_args(az.plot_forest, alias="forestplot") |
| 53 | +kdeplot = map_args(az.plot_kde, alias="kdeplot") |
| 54 | +plot_posterior = map_args(az.plot_posterior, alias="plot_posterior") |
| 55 | +energyplot = map_args(az.plot_energy, alias="energyplot") |
| 56 | +densityplot = map_args(az.plot_density, alias="densityplot") |
| 57 | +pairplot = map_args(az.plot_pair, alias="pairplot") |
| 58 | +traceplot = map_args(az.plot_trace, alias="traceplot") |
63 | 59 |
|
64 | 60 |
|
65 |
| -# addition arg mapping for compare plot |
| 61 | +# Customized with kwarg reformatting |
66 | 62 | @functools.wraps(az.plot_compare)
|
67 | 63 | def compareplot(*args, **kwargs):
|
| 64 | + warnings.warn( |
| 65 | + f"The function `compareplot` from PyMC3 is an alias for `plot_compare` from ArviZ. " |
| 66 | + "It also applies some kwarg replacements. Nevertheless, please switch " |
| 67 | + f"to `pymc3.plot_compare` or `arviz.plot_compare`.", |
| 68 | + DeprecationWarning, |
| 69 | + ) |
68 | 70 | if "comp_df" in kwargs:
|
69 | 71 | comp_df = kwargs["comp_df"].copy()
|
70 | 72 | else:
|
@@ -103,10 +105,6 @@ def compareplot(*args, **kwargs):
|
103 | 105 |
|
104 | 106 | from pymc3.plots.posteriorplot import plot_posterior_predictive_glm
|
105 | 107 |
|
106 |
| -# Access to arviz plots: base plots provided by arviz |
107 |
| -for plot in az.plots.__all__: |
108 |
| - setattr(sys.modules[__name__], plot, map_args(getattr(az.plots, plot))) |
109 |
| - |
110 | 108 | __all__ = tuple(az.plots.__all__) + (
|
111 | 109 | "autocorrplot",
|
112 | 110 | "compareplot",
|
|
0 commit comments