@@ -138,11 +138,7 @@ def binomial_beta(*, n: NUMERIC, x: NUMERIC, prior: Beta) -> Beta:
138138
139139 prior = Beta(1, 1)
140140
141- posterior = binomial_beta(
142- n=impressions,
143- x=clicks,
144- prior=prior
145- )
141+ posterior = binomial_beta(n=impressions, x=clicks, prior=prior)
146142
147143 ax = plt.subplot(111)
148144 posterior.set_bounds(0, 0.5).plot_pdf(ax=ax, label=["A", "B"])
@@ -191,15 +187,8 @@ def binomial_beta_predictive(*, n: NUMERIC, distribution: Beta) -> BetaBinomial:
191187 clicks = np.array([10, 35])
192188
193189 prior = Beta(1, 1)
194- posterior = binomial_beta(
195- n=impressions,
196- x=clicks,
197- prior=prior
198- )
199- posterior_predictive = binomial_beta_predictive(
200- n=100,
201- distribution=posterior
202- )
190+ posterior = binomial_beta(n=impressions, x=clicks, prior=prior)
191+ posterior_predictive = binomial_beta_predictive(n=100, distribution=posterior)
203192
204193
205194 ax = plt.subplot(111)
@@ -242,10 +231,7 @@ def bernoulli_beta(*, x: NUMERIC, prior: Beta) -> Beta:
242231
243232 # Positive outcome
244233 x = 1
245- posterior = bernoulli_beta(
246- x=x,
247- prior=prior
248- )
234+ posterior = bernoulli_beta(x=x, prior=prior)
249235
250236 posterior.dist.ppf([0.025, 0.975])
251237 # array([0.15811388, 0.98742088])
@@ -372,11 +358,7 @@ def geometric_beta(*, x_total, n, prior: Beta, one_start: bool = True) -> Beta:
372358 data = np.array([3, 1, 1, 3, 2, 1])
373359
374360 prior = Beta(1, 1)
375- posterior = geometric_beta(
376- x_total=data.sum(),
377- n=data.size,
378- prior=prior
379- )
361+ posterior = geometric_beta(x_total=data.sum(), n=data.size, prior=prior)
380362
381363 ax = plt.subplot(111)
382364 posterior.set_bounds(0, 1).plot_pdf(ax=ax, label="posterior")
@@ -510,17 +492,16 @@ def multinomial_dirichlet(*, x: NUMERIC, prior: Dirichlet) -> Dirichlet:
510492 from conjugate.models import multinomial_dirichlet
511493
512494 kinds = ["chocolate", "vanilla", "strawberry"]
513- data = np.array([
514- [5, 2, 1],
515- [3, 1, 0],
516- [3, 2, 0],
517- ])
495+ data = np.array(
496+ [
497+ [5, 2, 1],
498+ [3, 1, 0],
499+ [3, 2, 0],
500+ ]
501+ )
518502
519503 prior = Dirichlet([1, 1, 1])
520- posterior = multinomial_dirichlet(
521- x=data.sum(axis=0),
522- prior=prior
523- )
504+ posterior = multinomial_dirichlet(x=data.sum(axis=0), prior=prior)
524505
525506 ax = plt.subplot(111)
526507 posterior.plot_pdf(ax=ax, label=kinds)
@@ -669,11 +650,7 @@ def exponential_gamma_predictive(*, distribution: Gamma) -> Lomax:
669650
670651 prior = Gamma(1, 1)
671652
672- posterior = exponential_gamma(
673- n=n_samples,
674- x_total=data.sum(),
675- prior=prior
676- )
653+ posterior = exponential_gamma(n=n_samples, x_total=data.sum(), prior=prior)
677654
678655 prior_predictive = exponential_gamma_predictive(distribution=prior)
679656 posterior_predictive = exponential_gamma_predictive(distribution=posterior)
@@ -911,10 +888,7 @@ def normal_known_variance_predictive(*, var: NUMERIC, distribution: Normal) -> N
911888 prior = Normal(0, 10)
912889
913890 posterior = normal_known_variance(
914- n=n_samples,
915- x_total=data.sum(),
916- var=known_var,
917- prior=prior
891+ n=n_samples, x_total=data.sum(), var=known_var, prior=prior
918892 )
919893
920894 prior_predictive = normal_known_variance_predictive(
@@ -929,7 +903,9 @@ def normal_known_variance_predictive(*, var: NUMERIC, distribution: Normal) -> N
929903 bound = 5
930904 ax = plt.subplot(111)
931905 true.set_bounds(-bound, bound).plot_pdf(ax=ax, label="true distribution")
932- posterior_predictive.set_bounds(-bound, bound).plot_pdf(ax=ax, label="posterior predictive")
906+ posterior_predictive.set_bounds(-bound, bound).plot_pdf(
907+ ax=ax, label="posterior predictive"
908+ )
933909 prior_predictive.set_bounds(-bound, bound).plot_pdf(ax=ax, label="prior predictive")
934910 ax.legend()
935911 ```
@@ -985,10 +961,7 @@ def normal_known_precision(
985961 prior = Normal(0, 10)
986962
987963 posterior = normal_known_precision(
988- n=n_samples,
989- x_total=data.sum(),
990- precision=known_precision,
991- prior=prior
964+ n=n_samples, x_total=data.sum(), precision=known_precision, prior=prior
992965 )
993966
994967 bound = 5
@@ -1050,10 +1023,7 @@ def normal_known_precision_predictive(
10501023 prior = Normal(0, 10)
10511024
10521025 posterior = normal_known_precision(
1053- n=n_samples,
1054- x_total=data.sum(),
1055- precision=known_precision,
1056- prior=prior
1026+ n=n_samples, x_total=data.sum(), precision=known_precision, prior=prior
10571027 )
10581028
10591029 prior_predictive = normal_known_precision_predictive(
@@ -1068,7 +1038,9 @@ def normal_known_precision_predictive(
10681038 bound = 5
10691039 ax = plt.subplot(111)
10701040 true.set_bounds(-bound, bound).plot_pdf(ax=ax, label="true distribution")
1071- posterior_predictive.set_bounds(-bound, bound).plot_pdf(ax=ax, label="posterior predictive")
1041+ posterior_predictive.set_bounds(-bound, bound).plot_pdf(
1042+ ax=ax, label="posterior predictive"
1043+ )
10721044 prior_predictive.set_bounds(-bound, bound).plot_pdf(ax=ax, label="prior predictive")
10731045 ax.legend()
10741046 ```
@@ -1179,7 +1151,7 @@ def normal_known_mean_predictive(
11791151 x_total=data.sum(),
11801152 x2_total=(data**2).sum(),
11811153 mu=known_mu,
1182- prior=prior
1154+ prior=prior,
11831155 )
11841156
11851157 bound = 5
@@ -1194,7 +1166,9 @@ def normal_known_mean_predictive(
11941166 mu=known_mu,
11951167 distribution=posterior,
11961168 )
1197- posterior_predictive.set_bounds(-bound, bound).plot_pdf(ax=ax, label="posterior predictive")
1169+ posterior_predictive.set_bounds(-bound, bound).plot_pdf(
1170+ ax=ax, label="posterior predictive"
1171+ )
11981172 ax.legend()
11991173 ```
12001174 <!--
@@ -1486,11 +1460,7 @@ def uniform_pareto(
14861460
14871461 prior = Pareto(1, 1)
14881462
1489- posterior = uniform_pareto(
1490- x_max=data.max(),
1491- n=n_samples,
1492- prior=prior
1493- )
1463+ posterior = uniform_pareto(x_max=data.max(), n=n_samples, prior=prior)
14941464 ```
14951465
14961466 """
@@ -1907,10 +1877,12 @@ def multivariate_normal(
19071877 from conjugate.models import multivariate_normal
19081878
19091879 true_mean = np.array([1, 5])
1910- true_cov = np.array([
1911- [1, 0.5],
1912- [0.5, 1],
1913- ])
1880+ true_cov = np.array(
1881+ [
1882+ [1, 0.5],
1883+ [0.5, 1],
1884+ ]
1885+ )
19141886
19151887 n_samples = 100
19161888 rng = np.random.default_rng(42)
@@ -1924,10 +1896,12 @@ def multivariate_normal(
19241896 mu=np.array([0, 0]),
19251897 kappa=1,
19261898 nu=3,
1927- psi=np.array([
1928- [1, 0],
1929- [0, 1],
1930- ]),
1899+ psi=np.array(
1900+ [
1901+ [1, 0],
1902+ [0, 1],
1903+ ]
1904+ ),
19311905 )
19321906
19331907 posterior = multivariate_normal(
@@ -1990,10 +1964,12 @@ def multivariate_normal_predictive(
19901964 sigma_2 = 1.5
19911965 rho = -0.65
19921966 true_mean = np.array([mu_1, mu_2])
1993- true_cov = np.array([
1994- [sigma_1 ** 2, rho * sigma_1 * sigma_2],
1995- [rho * sigma_1 * sigma_2, sigma_2 ** 2],
1996- ])
1967+ true_cov = np.array(
1968+ [
1969+ [sigma_1**2, rho * sigma_1 * sigma_2],
1970+ [rho * sigma_1 * sigma_2, sigma_2**2],
1971+ ]
1972+ )
19971973 true = MultivariateNormal(true_mean, true_cov)
19981974
19991975 n_samples = 100
@@ -2004,10 +1980,12 @@ def multivariate_normal_predictive(
20041980 mu=np.array([0, 0]),
20051981 kappa=1,
20061982 nu=2,
2007- psi=np.array([
2008- [5 ** 2, 0],
2009- [0, 5 ** 2],
2010- ]),
1983+ psi=np.array(
1984+ [
1985+ [5**2, 0],
1986+ [0, 5**2],
1987+ ]
1988+ ),
20111989 )
20121990
20131991 posterior = multivariate_normal(
@@ -2021,13 +1999,15 @@ def multivariate_normal_predictive(
20211999
20222000 xmax = mu_1 + 3 * sigma_1
20232001 ymax = mu_2 + 3 * sigma_2
2024- x, y = np.mgrid[-xmax:xmax:.1, -ymax:ymax:.1]
2002 x, y = np.mgrid[-xmax:xmax:0.1, -ymax:ymax:0.1]
20252003 pos = np.dstack((x, y))
20262004 z = true.dist.pdf(pos)
20272005 # z = np.where(z < 0.005, np.nan, z)
20282006 contours = ax.contour(x, y, z, alpha=0.55, color="black")
20292007
2030- for label, dist in zip(["prior", "posterior"], [prior_predictive, posterior_predictive]):
2008+ for label, dist in zip(
2009+ ["prior", "posterior"], [prior_predictive, posterior_predictive]
2010+ ):
20312011 X = dist.dist.rvs(size=1000)
20322012 ax.scatter(X[:, 0], X[:, 1], alpha=0.15, label=f"{label} predictive")
20332013
@@ -2110,10 +2090,7 @@ def log_normal(
21102090
21112091 prior = NormalInverseGamma(mu=1, nu=1, alpha=1, beta=1)
21122092 posterior = log_normal_normal_inverse_gamma(
2113- ln_x_total=ln_data.sum(),
2114- ln_x2_total=(ln_data**2).sum(),
2115- n=n_samples,
2116- prior=prior
2093+ ln_x_total=ln_data.sum(), ln_x2_total=(ln_data**2).sum(), n=n_samples, prior=prior
21172094 )
21182095
21192096 fig, axes = plt.subplots(ncols=2)
0 commit comments