4040import warnings
4141
4242from conjugate .distributions import (
43+ Multinomial ,
44+ Poisson ,
45+ Exponential ,
46+ Binomial ,
47+ Uniform ,
48+ VonMises ,
49+ LogNormal ,
50+ Weibull ,
51+ Bernoulli ,
4352 Beta ,
53+ Geometric ,
4454 BetaBinomial ,
4555 BetaGeometric ,
4656 BetaNegativeBinomial ,
5161 Gamma ,
5262 GammaKnownRateProportional ,
5363 GammaProportional ,
64+ Hypergeometric ,
5465 InverseGamma ,
5566 InverseWishart ,
5667 Lomax ,
7081from conjugate ._typing import NUMERIC
7182
7283
84+ def add_associated_likelihood (name ):
85+ def decorator (func : Callable ) -> Callable :
86+ """Decorator to add an associated distribution to the function."""
87+ setattr (func , "associated_likelihood" , name )
88+ return func
89+
90+ return decorator
91+
92+
7393def validate_type (func , parameter : str ):
7494 expected_type = func .__annotations__ .get (parameter , None )
7595
@@ -110,6 +130,7 @@ def get_binomial_beta_posterior_params(
110130 return alpha_post , beta_post
111131
112132
133+ @add_associated_likelihood (Binomial )
113134@validate_prior_type
114135def binomial_beta (* , n : NUMERIC , x : NUMERIC , prior : Beta ) -> Beta :
115136 """Posterior distribution for a binomial likelihood with a beta prior.
@@ -161,6 +182,7 @@ def binomial_beta(*, n: NUMERIC, x: NUMERIC, prior: Beta) -> Beta:
161182 return Beta (alpha = alpha_post , beta = beta_post )
162183
163184
185+ @add_associated_likelihood (Binomial )
164186@validate_distribution_type
165187def binomial_beta_predictive (* , n : NUMERIC , distribution : Beta ) -> BetaBinomial :
166188 """Posterior predictive distribution for a binomial likelihood with a beta prior.
@@ -209,6 +231,7 @@ def binomial_beta_predictive(*, n: NUMERIC, distribution: Beta) -> BetaBinomial:
209231 return BetaBinomial (n = n , alpha = distribution .alpha , beta = distribution .beta )
210232
211233
234+ @add_associated_likelihood (Bernoulli )
212235@validate_prior_type
213236def bernoulli_beta (* , x : NUMERIC , prior : Beta ) -> Beta :
214237 """Posterior distribution for a bernoulli likelihood with a beta prior.
@@ -241,6 +264,7 @@ def bernoulli_beta(*, x: NUMERIC, prior: Beta) -> Beta:
241264 return binomial_beta (n = 1 , x = x , prior = prior )
242265
243266
267+ @add_associated_likelihood (Bernoulli )
244268@validate_distribution_type
245269def bernoulli_beta_predictive (* , distribution : Beta ) -> BetaBinomial :
246270 """Predictive distribution for a bernoulli likelihood with a beta prior.
@@ -257,6 +281,7 @@ def bernoulli_beta_predictive(*, distribution: Beta) -> BetaBinomial:
257281 return binomial_beta_predictive (n = 1 , distribution = distribution )
258282
259283
284+ @add_associated_likelihood (NegativeBinomial )
260285@validate_prior_type
261286def negative_binomial_beta (* , r : NUMERIC , n : NUMERIC , x : NUMERIC , prior : Beta ) -> Beta :
262287 """Posterior distribution for a negative binomial likelihood with a beta prior.
@@ -279,6 +304,7 @@ def negative_binomial_beta(*, r: NUMERIC, n: NUMERIC, x: NUMERIC, prior: Beta) -
279304 return Beta (alpha = alpha_post , beta = beta_post )
280305
281306
307+ @add_associated_likelihood (NegativeBinomial )
282308@validate_distribution_type
283309def negative_binomial_beta_predictive (
284310 * ,
@@ -300,6 +326,7 @@ def negative_binomial_beta_predictive(
300326 return BetaNegativeBinomial (n = r , alpha = distribution .alpha , beta = distribution .beta )
301327
302328
329+ @add_associated_likelihood (Hypergeometric )
303330@validate_prior_type
304331def hypergeometric_beta_binomial (
305332 * ,
@@ -329,6 +356,7 @@ def hypergeometric_beta_binomial(
329356 return BetaBinomial (n = n , alpha = alpha_post , beta = beta_post )
330357
331358
359+ @add_associated_likelihood (Geometric )
332360@validate_prior_type
333361def geometric_beta (* , x_total , n , prior : Beta , one_start : bool = True ) -> Beta :
334362 """Posterior distribution for a geometric likelihood with a beta prior.
@@ -383,6 +411,7 @@ def geometric_beta(*, x_total, n, prior: Beta, one_start: bool = True) -> Beta:
383411 return Beta (alpha = alpha_post , beta = beta_post )
384412
385413
414+ @add_associated_likelihood (Geometric )
386415@validate_distribution_type
387416def geometric_beta_predictive (
388417 * ,
@@ -469,6 +498,7 @@ def get_multi_categorical_dirichlet_posterior_params(
469498 return get_dirichlet_posterior_params (alpha_prior , x )
470499
471500
501+ @add_associated_likelihood (Multinomial )
472502@validate_prior_type
473503def multinomial_dirichlet (* , x : NUMERIC , prior : Dirichlet ) -> Dirichlet :
474504 """Posterior distribution of Multinomial model with Dirichlet prior.
@@ -522,6 +552,7 @@ def multinomial_dirichlet(*, x: NUMERIC, prior: Dirichlet) -> Dirichlet:
522552 return Dirichlet (alpha = alpha_post )
523553
524554
555+ @add_associated_likelihood (Multinomial )
525556@validate_distribution_type
526557def multinomial_dirichlet_predictive (
527558 * ,
@@ -555,6 +586,7 @@ def get_poisson_gamma_posterior_params(
555586 return alpha_post , beta_post
556587
557588
589+ @add_associated_likelihood (Poisson )
558590@validate_prior_type
559591def poisson_gamma (* , x_total : NUMERIC , n : NUMERIC , prior : Gamma ) -> Gamma :
560592 """Posterior distribution for a poisson likelihood with a gamma prior.
@@ -575,6 +607,7 @@ def poisson_gamma(*, x_total: NUMERIC, n: NUMERIC, prior: Gamma) -> Gamma:
575607 return Gamma (alpha = alpha_post , beta = beta_post )
576608
577609
610+ @add_associated_likelihood (Poisson )
578611@validate_distribution_type
579612def poisson_gamma_predictive (
580613 * ,
@@ -602,6 +635,7 @@ def poisson_gamma_predictive(
602635get_exponential_gamma_posterior_params = get_poisson_gamma_posterior_params
603636
604637
638+ @add_associated_likelihood (Exponential )
605639@validate_prior_type
606640def exponential_gamma (* , x_total : NUMERIC , n : NUMERIC , prior : Gamma ) -> Gamma :
607641 """Posterior distribution for an exponential likelihood with a gamma prior.
@@ -622,6 +656,7 @@ def exponential_gamma(*, x_total: NUMERIC, n: NUMERIC, prior: Gamma) -> Gamma:
622656 return Gamma (alpha = alpha_post , beta = beta_post )
623657
624658
659+ @add_associated_likelihood (Exponential )
625660@validate_distribution_type
626661def exponential_gamma_predictive (* , distribution : Gamma ) -> Lomax :
627662 """Predictive distribution for an exponential likelihood with a gamma distribution
@@ -672,6 +707,7 @@ def exponential_gamma_predictive(*, distribution: Gamma) -> Lomax:
672707 return Lomax (alpha = distribution .beta , lam = distribution .alpha )
673708
674709
710+ @add_associated_likelihood (Gamma )
675711@validate_prior_type
676712def gamma_known_shape (
677713 * ,
@@ -741,6 +777,7 @@ def gamma_known_shape(
741777 return Gamma (alpha = alpha_post , beta = beta_post )
742778
743779
780+ @add_associated_likelihood (Gamma )
744781@validate_distribution_type
745782def gamma_known_shape_predictive (
746783 * ,
@@ -760,6 +797,7 @@ def gamma_known_shape_predictive(
760797 return CompoundGamma (alpha = alpha , beta = distribution .alpha , lam = distribution .beta )
761798
762799
800+ @add_associated_likelihood (InverseGamma )
763801@validate_prior_type
764802def inverse_gamma_known_rate (
765803 * ,
@@ -786,6 +824,7 @@ def inverse_gamma_known_rate(
786824 return Gamma (alpha = alpha_post , beta = beta_post )
787825
788826
827+ @add_associated_likelihood (Normal )
789828@validate_prior_type
790829def normal_known_variance (
791830 * ,
@@ -856,6 +895,7 @@ def normal_known_variance(
856895 return Normal (mu = mu_post , sigma = var_post ** 0.5 )
857896
858897
898+ @add_associated_likelihood (Normal )
859899@validate_distribution_type
860900def normal_known_variance_predictive (* , var : NUMERIC , distribution : Normal ) -> Normal :
861901 """Predictive distribution for a normal likelihood with known variance and a normal distribution on mean.
@@ -921,6 +961,7 @@ def normal_known_variance_predictive(*, var: NUMERIC, distribution: Normal) -> N
921961 return Normal (mu = distribution .mu , sigma = var_posterior_predictive ** 0.5 )
922962
923963
964+ @add_associated_likelihood (Normal )
924965@validate_prior_type
925966def normal_known_precision (
926967 * ,
@@ -987,6 +1028,7 @@ def normal_known_precision(
9871028 )
9881029
9891030
1031+ @add_associated_likelihood (Normal )
9901032@validate_distribution_type
9911033def normal_known_precision_predictive (
9921034 * ,
@@ -1072,6 +1114,7 @@ def _normal_known_mean_inverse_gamma_prior(
10721114 return InverseGamma (alpha = alpha_post , beta = beta_post )
10731115
10741116
1117+ @add_associated_likelihood (Normal )
10751118@validate_prior_type
10761119def normal_known_mean (
10771120 * ,
@@ -1111,6 +1154,7 @@ def normal_known_mean(
11111154 return posterior
11121155
11131156
1157+ @add_associated_likelihood (Normal )
11141158@validate_distribution_type
11151159def normal_known_mean_predictive (
11161160 * ,
@@ -1209,6 +1253,7 @@ def _normal(
12091253 return mu_post , nu_post , alpha_post , beta_post
12101254
12111255
1256+ @add_associated_likelihood (Normal )
12121257@validate_prior_type
12131258def normal (
12141259 * ,
@@ -1261,6 +1306,7 @@ def normal(
12611306 return prior .__class__ (** kwargs )
12621307
12631308
1309+ @add_associated_likelihood (Normal )
12641310@validate_prior_type
12651311def normal_normal_inverse_gamma (
12661312 * ,
@@ -1292,6 +1338,7 @@ def normal_normal_inverse_gamma(
12921338 return normal (x_total = x_total , x2_total = x2_total , n = n , prior = prior ) # type: ignore
12931339
12941340
1341+ @add_associated_likelihood (Normal )
12951342@validate_distribution_type
12961343def normal_predictive (
12971344 * ,
@@ -1319,6 +1366,7 @@ def normal_predictive(
13191366 )
13201367
13211368
1369+ @add_associated_likelihood (Normal )
13221370@validate_distribution_type
13231371def normal_normal_inverse_gamma_predictive (
13241372 * ,
@@ -1426,6 +1474,7 @@ def linear_regression_predictive(
14261474 )
14271475
14281476
1477+ @add_associated_likelihood (Uniform )
14291478@validate_prior_type
14301479def uniform_pareto (
14311480 * ,
@@ -1470,6 +1519,7 @@ def uniform_pareto(
14701519 return Pareto (x_m = x_m_post , alpha = alpha_post )
14711520
14721521
1522+ @add_associated_likelihood (Pareto )
14731523@validate_prior_type
14741524def pareto_gamma (
14751525 * ,
@@ -1539,6 +1589,7 @@ def pareto_gamma(
15391589 return Gamma (alpha = alpha_post , beta = beta_post )
15401590
15411591
1592+ @add_associated_likelihood (Gamma )
15421593@validate_prior_type
15431594def gamma (
15441595 * ,
@@ -1569,6 +1620,7 @@ def gamma(
15691620 return GammaProportional (p = p_post , q = q_post , r = r_post , s = s_post )
15701621
15711622
1623+ @add_associated_likelihood (Gamma )
15721624@validate_prior_type
15731625def gamma_known_rate (
15741626 * ,
@@ -1597,6 +1649,7 @@ def gamma_known_rate(
15971649 return GammaKnownRateProportional (a = a_post , b = b_post , c = c_post )
15981650
15991651
1652+ @add_associated_likelihood (Beta )
16001653@validate_prior_type
16011654def beta (
16021655 * ,
@@ -1626,6 +1679,7 @@ def beta(
16261679 return BetaProportional (p = p_post , q = q_post , k = k_post )
16271680
16281681
1682+ @add_associated_likelihood (VonMises )
16291683@validate_prior_type
16301684def von_mises_known_concentration (
16311685 * ,
@@ -1661,6 +1715,7 @@ def von_mises_known_concentration(
16611715 return VonMisesKnownConcentration (a = a_post , b = b_post )
16621716
16631717
1718+ @add_associated_likelihood (VonMises )
16641719@validate_prior_type
16651720def von_mises_known_direction (
16661721 * ,
@@ -1701,6 +1756,7 @@ def _multivariate_normal_known_precision(
17011756 return mu_post , precision_post
17021757
17031758
1759+ @add_associated_likelihood (MultivariateNormal )
17041760@validate_prior_type
17051761def multivariate_normal_known_covariance (
17061762 * ,
@@ -1739,6 +1795,7 @@ def multivariate_normal_known_covariance(
17391795 return MultivariateNormal (mu = mu_post , cov = inv (precision_post ))
17401796
17411797
1798+ @add_associated_likelihood (MultivariateNormal )
17421799@validate_distribution_type
17431800def multivariate_normal_known_covariance_predictive (
17441801 * ,
@@ -1760,6 +1817,7 @@ def multivariate_normal_known_covariance_predictive(
17601817 return MultivariateNormal (mu = mu_pred , cov = cov_pred )
17611818
17621819
1820+ @add_associated_likelihood (MultivariateNormal )
17631821@validate_prior_type
17641822def multivariate_normal_known_precision (
17651823 * ,
@@ -1796,6 +1854,7 @@ def multivariate_normal_known_precision(
17961854 return MultivariateNormal (mu = mu_post , cov = inv (precision_post ))
17971855
17981856
1857+ @add_associated_likelihood (MultivariateNormal )
17991858@validate_distribution_type
18001859def multivariate_normal_known_precision_predictive (
18011860 * ,
@@ -1819,6 +1878,7 @@ def multivariate_normal_known_precision_predictive(
18191878 return MultivariateNormal (mu = mu_pred , cov = cov_pred )
18201879
18211880
1881+ @add_associated_likelihood (MultivariateNormal )
18221882@validate_prior_type
18231883def multivariate_normal_known_mean (
18241884 * ,
@@ -1846,6 +1906,7 @@ def multivariate_normal_known_mean(
18461906 )
18471907
18481908
1909+ @add_associated_likelihood (MultivariateNormal )
18491910@validate_prior_type
18501911def multivariate_normal (
18511912 * ,
@@ -1934,6 +1995,7 @@ def multivariate_normal(
19341995 )
19351996
19361997
1998+ @add_associated_likelihood (MultivariateNormal )
19371999@validate_distribution_type
19382000def multivariate_normal_predictive (
19392001 * ,
@@ -2045,6 +2107,7 @@ def multivariate_normal_predictive(
20452107 return MultivariateStudentT (mu = mu , sigma = sigma , nu = nu )
20462108
20472109
2110+ @add_associated_likelihood (LogNormal )
20482111@validate_prior_type
20492112def log_normal (
20502113 * ,
@@ -2120,6 +2183,7 @@ def log_normal(
21202183 )
21212184
21222185
2186+ @add_associated_likelihood (Weibull )
21232187@validate_prior_type
21242188def weibull_inverse_gamma_known_shape (
21252189 * ,
0 commit comments