33import numpy as np
44
55import matplotlib .pyplot as plt
6+ from matplotlib .figure import Figure
67
78from conjugate .distributions import (
89 Beta ,
2324
2425
2526@pytest .mark .mpl_image_compare
26- def test_label () -> None :
27+ def test_label () -> Figure :
2728 fig , ax = plt .subplots (figsize = FIGSIZE )
2829 beta = Beta (1 , 1 )
2930 beta .plot_pdf (ax = ax , label = "Uniform" )
@@ -32,7 +33,7 @@ def test_label() -> None:
3233
3334
3435@pytest .mark .mpl_image_compare
35- def test_multiple_labels_str () -> None :
36+ def test_multiple_labels_str () -> Figure :
3637 fig , ax = plt .subplots (figsize = FIGSIZE )
3738 beta = Beta (np .array ([1 , 2 , 3 ]), np .array ([1 , 2 , 3 ]))
3839 beta .plot_pdf (label = "Beta" , ax = ax )
@@ -41,7 +42,7 @@ def test_multiple_labels_str() -> None:
4142
4243
4344@pytest .mark .mpl_image_compare
44- def test_multiple_with_labels () -> None :
45+ def test_multiple_with_labels () -> Figure :
4546 fig , ax = plt .subplots (figsize = FIGSIZE )
4647 beta = Beta (np .array ([1 , 2 , 3 ]), np .array ([1 , 2 , 3 ]))
4748 ax = beta .plot_pdf (label = ["First Beta" , "Second Beta" , "Third Beta" ], ax = ax )
@@ -50,7 +51,7 @@ def test_multiple_with_labels() -> None:
5051
5152
5253@pytest .mark .mpl_image_compare
53- def test_skip_label () -> None :
54+ def test_skip_label () -> Figure :
5455 fig , ax = plt .subplots (figsize = FIGSIZE )
5556 beta = Beta (np .array ([1 , 2 , 3 ]), np .array ([1 , 2 , 3 ]))
5657 ax = beta .plot_pdf (label = ["First Beta" , None , "Third Beta" ], ax = ax )
@@ -59,7 +60,7 @@ def test_skip_label() -> None:
5960
6061
6162@pytest .mark .mpl_image_compare
62- def test_different_distributions () -> None :
63+ def test_different_distributions () -> Figure :
6364 fig , ax = plt .subplots (figsize = FIGSIZE )
6465 beta = Beta (1 , np .array ([1 , 2 ]))
6566 gamma = Gamma (1 , 1 )
@@ -74,7 +75,7 @@ def test_different_distributions() -> None:
7475
7576
7677@pytest .mark .mpl_image_compare
77- def test_analysis () -> None :
78+ def test_analysis () -> Figure :
7879 prior = Beta (1 , 1 )
7980
8081 N = 10
@@ -108,15 +109,15 @@ def test_analysis() -> None:
108109
109110
110111@pytest .mark .mpl_image_compare
111- def test_dirichlet () -> None :
112+ def test_dirichlet () -> Figure :
112113 fig , ax = plt .subplots (figsize = FIGSIZE )
113114 dirichlet = Dirichlet (np .array ([1 , 2 , 3 ]))
114115 ax = dirichlet .plot_pdf (random_state = 0 , ax = ax )
115116 return fig
116117
117118
118119@pytest .mark .mpl_image_compare
119- def test_dirichlet_labels () -> None :
120+ def test_dirichlet_labels () -> Figure :
120121 fig , ax = plt .subplots (figsize = FIGSIZE )
121122 dirichlet = Dirichlet (np .array ([1 , 2 , 3 ]))
122123 ax = dirichlet .plot_pdf (random_state = 0 , label = "Category" , ax = ax )
@@ -125,7 +126,7 @@ def test_dirichlet_labels() -> None:
125126
126127
127128@pytest .mark .mpl_image_compare
128- def test_dirichlet_multiple_labels () -> None :
129+ def test_dirichlet_multiple_labels () -> Figure :
129130 fig , ax = plt .subplots (figsize = FIGSIZE )
130131 dirichlet = Dirichlet (np .array ([1 , 2 , 3 ]))
131132 ax = dirichlet .plot_pdf (
@@ -138,7 +139,7 @@ def test_dirichlet_multiple_labels() -> None:
138139
139140
140141@pytest .mark .mpl_image_compare
141- def test_bayesian_update_example () -> None :
142+ def test_bayesian_update_example () -> Figure :
142143 def create_sampler (mu , sigma , rng ):
143144 def sample (n : int ):
144145 return rng .normal (loc = mu , scale = sigma , size = n )
@@ -196,7 +197,7 @@ def sample(n: int):
196197
197198
198199@pytest .mark .mpl_image_compare
199- def test_polar_plot () -> None :
200+ def test_polar_plot () -> Figure :
200201 kappas = np .array ([0.5 , 1 , 5 , 10 ])
201202 dist = VonMises (0 , kappa = kappas )
202203
@@ -210,7 +211,7 @@ def test_polar_plot() -> None:
210211
211212
212213@pytest .mark .mpl_image_compare
213- def test_cdf_continuous () -> None :
214+ def test_cdf_continuous () -> Figure :
214215 dist = Normal (0 , 1 )
215216 dist .set_bounds (- 5 , 5 )
216217 fig , ax = plt .subplots (figsize = FIGSIZE )
@@ -220,7 +221,7 @@ def test_cdf_continuous() -> None:
220221
221222
222223@pytest .mark .mpl_image_compare
223- def test_cdf_discrete () -> None :
224+ def test_cdf_discrete () -> Figure :
224225 dist = Binomial (n = 10 , p = 0.25 )
225226 fig , ax = plt .subplots (figsize = FIGSIZE )
226227
@@ -229,7 +230,7 @@ def test_cdf_discrete() -> None:
229230
230231
231232@pytest .mark .mpl_image_compare
232- def test_conditional_plot () -> None :
233+ def test_conditional_plot () -> Figure :
233234 dist = Binomial (n = 10 , p = 0.25 )
234235 dist .set_bounds (3 , 7 )
235236
@@ -239,7 +240,7 @@ def test_conditional_plot() -> None:
239240
240241
241242@pytest .mark .mpl_image_compare
242- def test_conditional_plot_cdf () -> None :
243+ def test_conditional_plot_cdf () -> Figure :
243244 dist = Binomial (n = 10 , p = 0.25 )
244245 dist .set_bounds (3 , 7 )
245246
@@ -249,24 +250,36 @@ def test_conditional_plot_cdf() -> None:
249250
250251
251252@pytest .mark .mpl_image_compare
252- def test_color_cycle_continuous () -> None :
253+ def test_color_cycle_continuous () -> Figure :
253254 dist = Normal (mu = 0 , sigma = [1 , 2 , 3 ]).set_bounds (- 10 , 10 )
254255 fig , ax = plt .subplots (figsize = FIGSIZE )
255256 dist .plot_pdf (ax = ax , color = ["red" , "green" , "teal" ])
256257 return fig
257258
258259
259260@pytest .mark .mpl_image_compare
260- def test_color_cycle_discrete () -> None :
261+ def test_color_cycle_discrete () -> Figure :
261262 dist = Binomial (n = 10 , p = [0.15 , 0.5 , 0.85 ])
262263 fig , ax = plt .subplots (figsize = FIGSIZE )
263264 dist .plot_pmf (ax = ax , color = ["red" , "green" , "teal" ])
264265 return fig
265266
266267
267268@pytest .mark .mpl_image_compare
268- def test_large_discrete_x_axis () -> None :
269+ def test_large_discrete_x_axis () -> Figure :
269270 dist = Binomial (n = 50 , p = 0.5 )
270271 fig , ax = plt .subplots (figsize = FIGSIZE )
271272 dist .plot_pmf (ax = ax )
272273 return fig
274+
275+
276+ @pytest .mark .mpl_image_compare
277+ def test_rgba_colors () -> Figure :
278+ prior = Beta (1 , 1 )
279+ posterior = Beta (10 , 10 )
280+
281+ fig , ax = plt .subplots (figsize = FIGSIZE )
282+ prior .plot_pdf (ax = ax , color = (1 , 0 , 0 , 0.5 ), label = "Prior" )
283+ posterior .plot_pdf (ax = ax , color = (0 , 1 , 0 , 0.5 ), label = "Posterior" )
284+ ax .legend (title = "Distribution" )
285+ return fig
0 commit comments