Skip to content

Commit 1704854

Browse files
committed
add example of rgba colors
1 parent 8332c5b commit 1704854

File tree

2 files changed

+31
-18
lines changed

2 files changed

+31
-18
lines changed
28.1 KB
Loading

tests/test_example_plots.py

Lines changed: 31 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import numpy as np
44

55
import matplotlib.pyplot as plt
6+
from matplotlib.figure import Figure
67

78
from conjugate.distributions import (
89
Beta,
@@ -23,7 +24,7 @@
2324

2425

2526
@pytest.mark.mpl_image_compare
26-
def test_label() -> None:
27+
def test_label() -> Figure:
2728
fig, ax = plt.subplots(figsize=FIGSIZE)
2829
beta = Beta(1, 1)
2930
beta.plot_pdf(ax=ax, label="Uniform")
@@ -32,7 +33,7 @@ def test_label() -> None:
3233

3334

3435
@pytest.mark.mpl_image_compare
35-
def test_multiple_labels_str() -> None:
36+
def test_multiple_labels_str() -> Figure:
3637
fig, ax = plt.subplots(figsize=FIGSIZE)
3738
beta = Beta(np.array([1, 2, 3]), np.array([1, 2, 3]))
3839
beta.plot_pdf(label="Beta", ax=ax)
@@ -41,7 +42,7 @@ def test_multiple_labels_str() -> None:
4142

4243

4344
@pytest.mark.mpl_image_compare
44-
def test_multiple_with_labels() -> None:
45+
def test_multiple_with_labels() -> Figure:
4546
fig, ax = plt.subplots(figsize=FIGSIZE)
4647
beta = Beta(np.array([1, 2, 3]), np.array([1, 2, 3]))
4748
ax = beta.plot_pdf(label=["First Beta", "Second Beta", "Third Beta"], ax=ax)
@@ -50,7 +51,7 @@ def test_multiple_with_labels() -> None:
5051

5152

5253
@pytest.mark.mpl_image_compare
53-
def test_skip_label() -> None:
54+
def test_skip_label() -> Figure:
5455
fig, ax = plt.subplots(figsize=FIGSIZE)
5556
beta = Beta(np.array([1, 2, 3]), np.array([1, 2, 3]))
5657
ax = beta.plot_pdf(label=["First Beta", None, "Third Beta"], ax=ax)
@@ -59,7 +60,7 @@ def test_skip_label() -> None:
5960

6061

6162
@pytest.mark.mpl_image_compare
62-
def test_different_distributions() -> None:
63+
def test_different_distributions() -> Figure:
6364
fig, ax = plt.subplots(figsize=FIGSIZE)
6465
beta = Beta(1, np.array([1, 2]))
6566
gamma = Gamma(1, 1)
@@ -74,7 +75,7 @@ def test_different_distributions() -> None:
7475

7576

7677
@pytest.mark.mpl_image_compare
77-
def test_analysis() -> None:
78+
def test_analysis() -> Figure:
7879
prior = Beta(1, 1)
7980

8081
N = 10
@@ -108,15 +109,15 @@ def test_analysis() -> None:
108109

109110

110111
@pytest.mark.mpl_image_compare
111-
def test_dirichlet() -> None:
112+
def test_dirichlet() -> Figure:
112113
fig, ax = plt.subplots(figsize=FIGSIZE)
113114
dirichlet = Dirichlet(np.array([1, 2, 3]))
114115
ax = dirichlet.plot_pdf(random_state=0, ax=ax)
115116
return fig
116117

117118

118119
@pytest.mark.mpl_image_compare
119-
def test_dirichlet_labels() -> None:
120+
def test_dirichlet_labels() -> Figure:
120121
fig, ax = plt.subplots(figsize=FIGSIZE)
121122
dirichlet = Dirichlet(np.array([1, 2, 3]))
122123
ax = dirichlet.plot_pdf(random_state=0, label="Category", ax=ax)
@@ -125,7 +126,7 @@ def test_dirichlet_labels() -> None:
125126

126127

127128
@pytest.mark.mpl_image_compare
128-
def test_dirichlet_multiple_labels() -> None:
129+
def test_dirichlet_multiple_labels() -> Figure:
129130
fig, ax = plt.subplots(figsize=FIGSIZE)
130131
dirichlet = Dirichlet(np.array([1, 2, 3]))
131132
ax = dirichlet.plot_pdf(
@@ -138,7 +139,7 @@ def test_dirichlet_multiple_labels() -> None:
138139

139140

140141
@pytest.mark.mpl_image_compare
141-
def test_bayesian_update_example() -> None:
142+
def test_bayesian_update_example() -> Figure:
142143
def create_sampler(mu, sigma, rng):
143144
def sample(n: int):
144145
return rng.normal(loc=mu, scale=sigma, size=n)
@@ -196,7 +197,7 @@ def sample(n: int):
196197

197198

198199
@pytest.mark.mpl_image_compare
199-
def test_polar_plot() -> None:
200+
def test_polar_plot() -> Figure:
200201
kappas = np.array([0.5, 1, 5, 10])
201202
dist = VonMises(0, kappa=kappas)
202203

@@ -210,7 +211,7 @@ def test_polar_plot() -> None:
210211

211212

212213
@pytest.mark.mpl_image_compare
213-
def test_cdf_continuous() -> None:
214+
def test_cdf_continuous() -> Figure:
214215
dist = Normal(0, 1)
215216
dist.set_bounds(-5, 5)
216217
fig, ax = plt.subplots(figsize=FIGSIZE)
@@ -220,7 +221,7 @@ def test_cdf_continuous() -> None:
220221

221222

222223
@pytest.mark.mpl_image_compare
223-
def test_cdf_discrete() -> None:
224+
def test_cdf_discrete() -> Figure:
224225
dist = Binomial(n=10, p=0.25)
225226
fig, ax = plt.subplots(figsize=FIGSIZE)
226227

@@ -229,7 +230,7 @@ def test_cdf_discrete() -> None:
229230

230231

231232
@pytest.mark.mpl_image_compare
232-
def test_conditional_plot() -> None:
233+
def test_conditional_plot() -> Figure:
233234
dist = Binomial(n=10, p=0.25)
234235
dist.set_bounds(3, 7)
235236

@@ -239,7 +240,7 @@ def test_conditional_plot() -> None:
239240

240241

241242
@pytest.mark.mpl_image_compare
242-
def test_conditional_plot_cdf() -> None:
243+
def test_conditional_plot_cdf() -> Figure:
243244
dist = Binomial(n=10, p=0.25)
244245
dist.set_bounds(3, 7)
245246

@@ -249,24 +250,36 @@ def test_conditional_plot_cdf() -> None:
249250

250251

251252
@pytest.mark.mpl_image_compare
252-
def test_color_cycle_continuous() -> None:
253+
def test_color_cycle_continuous() -> Figure:
253254
dist = Normal(mu=0, sigma=[1, 2, 3]).set_bounds(-10, 10)
254255
fig, ax = plt.subplots(figsize=FIGSIZE)
255256
dist.plot_pdf(ax=ax, color=["red", "green", "teal"])
256257
return fig
257258

258259

259260
@pytest.mark.mpl_image_compare
260-
def test_color_cycle_discrete() -> None:
261+
def test_color_cycle_discrete() -> Figure:
261262
dist = Binomial(n=10, p=[0.15, 0.5, 0.85])
262263
fig, ax = plt.subplots(figsize=FIGSIZE)
263264
dist.plot_pmf(ax=ax, color=["red", "green", "teal"])
264265
return fig
265266

266267

267268
@pytest.mark.mpl_image_compare
268-
def test_large_discrete_x_axis() -> None:
269+
def test_large_discrete_x_axis() -> Figure:
269270
dist = Binomial(n=50, p=0.5)
270271
fig, ax = plt.subplots(figsize=FIGSIZE)
271272
dist.plot_pmf(ax=ax)
272273
return fig
274+
275+
276+
@pytest.mark.mpl_image_compare
277+
def test_rgba_colors() -> Figure:
278+
prior = Beta(1, 1)
279+
posterior = Beta(10, 10)
280+
281+
fig, ax = plt.subplots(figsize=FIGSIZE)
282+
prior.plot_pdf(ax=ax, color=(1, 0, 0, 0.5), label="Prior")
283+
posterior.plot_pdf(ax=ax, color=(0, 1, 0, 0.5), label="Posterior")
284+
ax.legend(title="Distribution")
285+
return fig

0 commit comments

Comments
 (0)