Skip to content

Commit 85e263a

Browse files
committed
tag the models with likelihoods
1 parent b2391ec commit 85e263a

File tree

1 file changed

+64
-0
lines changed

1 file changed

+64
-0
lines changed

conjugate/models.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,17 @@
4040
import warnings
4141

4242
from conjugate.distributions import (
43+
Multinomial,
44+
Poisson,
45+
Exponential,
46+
Binomial,
47+
Uniform,
48+
VonMises,
49+
LogNormal,
50+
Weibull,
51+
Bernoulli,
4352
Beta,
53+
Geometric,
4454
BetaBinomial,
4555
BetaGeometric,
4656
BetaNegativeBinomial,
@@ -51,6 +61,7 @@
5161
Gamma,
5262
GammaKnownRateProportional,
5363
GammaProportional,
64+
Hypergeometric,
5465
InverseGamma,
5566
InverseWishart,
5667
Lomax,
@@ -70,6 +81,15 @@
7081
from conjugate._typing import NUMERIC
7182

7283

84+
def add_associated_likelihood(name):
85+
def decorator(func: Callable) -> Callable:
86+
"""Decorator to add an associated distribution to the function."""
87+
setattr(func, "associated_likelihood", name)
88+
return func
89+
90+
return decorator
91+
92+
7393
def validate_type(func, parameter: str):
7494
expected_type = func.__annotations__.get(parameter, None)
7595

@@ -110,6 +130,7 @@ def get_binomial_beta_posterior_params(
110130
return alpha_post, beta_post
111131

112132

133+
@add_associated_likelihood(Binomial)
113134
@validate_prior_type
114135
def binomial_beta(*, n: NUMERIC, x: NUMERIC, prior: Beta) -> Beta:
115136
"""Posterior distribution for a binomial likelihood with a beta prior.
@@ -161,6 +182,7 @@ def binomial_beta(*, n: NUMERIC, x: NUMERIC, prior: Beta) -> Beta:
161182
return Beta(alpha=alpha_post, beta=beta_post)
162183

163184

185+
@add_associated_likelihood(Binomial)
164186
@validate_distribution_type
165187
def binomial_beta_predictive(*, n: NUMERIC, distribution: Beta) -> BetaBinomial:
166188
"""Posterior predictive distribution for a binomial likelihood with a beta prior.
@@ -209,6 +231,7 @@ def binomial_beta_predictive(*, n: NUMERIC, distribution: Beta) -> BetaBinomial:
209231
return BetaBinomial(n=n, alpha=distribution.alpha, beta=distribution.beta)
210232

211233

234+
@add_associated_likelihood(Bernoulli)
212235
@validate_prior_type
213236
def bernoulli_beta(*, x: NUMERIC, prior: Beta) -> Beta:
214237
"""Posterior distribution for a bernoulli likelihood with a beta prior.
@@ -241,6 +264,7 @@ def bernoulli_beta(*, x: NUMERIC, prior: Beta) -> Beta:
241264
return binomial_beta(n=1, x=x, prior=prior)
242265

243266

267+
@add_associated_likelihood(Bernoulli)
244268
@validate_distribution_type
245269
def bernoulli_beta_predictive(*, distribution: Beta) -> BetaBinomial:
246270
"""Predictive distribution for a bernoulli likelihood with a beta prior.
@@ -257,6 +281,7 @@ def bernoulli_beta_predictive(*, distribution: Beta) -> BetaBinomial:
257281
return binomial_beta_predictive(n=1, distribution=distribution)
258282

259283

284+
@add_associated_likelihood(NegativeBinomial)
260285
@validate_prior_type
261286
def negative_binomial_beta(*, r: NUMERIC, n: NUMERIC, x: NUMERIC, prior: Beta) -> Beta:
262287
"""Posterior distribution for a negative binomial likelihood with a beta prior.
@@ -279,6 +304,7 @@ def negative_binomial_beta(*, r: NUMERIC, n: NUMERIC, x: NUMERIC, prior: Beta) -
279304
return Beta(alpha=alpha_post, beta=beta_post)
280305

281306

307+
@add_associated_likelihood(NegativeBinomial)
282308
@validate_distribution_type
283309
def negative_binomial_beta_predictive(
284310
*,
@@ -300,6 +326,7 @@ def negative_binomial_beta_predictive(
300326
return BetaNegativeBinomial(n=r, alpha=distribution.alpha, beta=distribution.beta)
301327

302328

329+
@add_associated_likelihood(Hypergeometric)
303330
@validate_prior_type
304331
def hypergeometric_beta_binomial(
305332
*,
@@ -329,6 +356,7 @@ def hypergeometric_beta_binomial(
329356
return BetaBinomial(n=n, alpha=alpha_post, beta=beta_post)
330357

331358

359+
@add_associated_likelihood(Geometric)
332360
@validate_prior_type
333361
def geometric_beta(*, x_total, n, prior: Beta, one_start: bool = True) -> Beta:
334362
"""Posterior distribution for a geometric likelihood with a beta prior.
@@ -383,6 +411,7 @@ def geometric_beta(*, x_total, n, prior: Beta, one_start: bool = True) -> Beta:
383411
return Beta(alpha=alpha_post, beta=beta_post)
384412

385413

414+
@add_associated_likelihood(Geometric)
386415
@validate_distribution_type
387416
def geometric_beta_predictive(
388417
*,
@@ -469,6 +498,7 @@ def get_multi_categorical_dirichlet_posterior_params(
469498
return get_dirichlet_posterior_params(alpha_prior, x)
470499

471500

501+
@add_associated_likelihood(Multinomial)
472502
@validate_prior_type
473503
def multinomial_dirichlet(*, x: NUMERIC, prior: Dirichlet) -> Dirichlet:
474504
"""Posterior distribution of Multinomial model with Dirichlet prior.
@@ -522,6 +552,7 @@ def multinomial_dirichlet(*, x: NUMERIC, prior: Dirichlet) -> Dirichlet:
522552
return Dirichlet(alpha=alpha_post)
523553

524554

555+
@add_associated_likelihood(Multinomial)
525556
@validate_distribution_type
526557
def multinomial_dirichlet_predictive(
527558
*,
@@ -555,6 +586,7 @@ def get_poisson_gamma_posterior_params(
555586
return alpha_post, beta_post
556587

557588

589+
@add_associated_likelihood(Poisson)
558590
@validate_prior_type
559591
def poisson_gamma(*, x_total: NUMERIC, n: NUMERIC, prior: Gamma) -> Gamma:
560592
"""Posterior distribution for a poisson likelihood with a gamma prior.
@@ -575,6 +607,7 @@ def poisson_gamma(*, x_total: NUMERIC, n: NUMERIC, prior: Gamma) -> Gamma:
575607
return Gamma(alpha=alpha_post, beta=beta_post)
576608

577609

610+
@add_associated_likelihood(Poisson)
578611
@validate_distribution_type
579612
def poisson_gamma_predictive(
580613
*,
@@ -602,6 +635,7 @@ def poisson_gamma_predictive(
602635
get_exponential_gamma_posterior_params = get_poisson_gamma_posterior_params
603636

604637

638+
@add_associated_likelihood(Exponential)
605639
@validate_prior_type
606640
def exponential_gamma(*, x_total: NUMERIC, n: NUMERIC, prior: Gamma) -> Gamma:
607641
"""Posterior distribution for an exponential likelihood with a gamma prior.
@@ -622,6 +656,7 @@ def exponential_gamma(*, x_total: NUMERIC, n: NUMERIC, prior: Gamma) -> Gamma:
622656
return Gamma(alpha=alpha_post, beta=beta_post)
623657

624658

659+
@add_associated_likelihood(Exponential)
625660
@validate_distribution_type
626661
def exponential_gamma_predictive(*, distribution: Gamma) -> Lomax:
627662
"""Predictive distribution for an exponential likelihood with a gamma distribution
@@ -672,6 +707,7 @@ def exponential_gamma_predictive(*, distribution: Gamma) -> Lomax:
672707
return Lomax(alpha=distribution.beta, lam=distribution.alpha)
673708

674709

710+
@add_associated_likelihood(Gamma)
675711
@validate_prior_type
676712
def gamma_known_shape(
677713
*,
@@ -741,6 +777,7 @@ def gamma_known_shape(
741777
return Gamma(alpha=alpha_post, beta=beta_post)
742778

743779

780+
@add_associated_likelihood(Gamma)
744781
@validate_distribution_type
745782
def gamma_known_shape_predictive(
746783
*,
@@ -760,6 +797,7 @@ def gamma_known_shape_predictive(
760797
return CompoundGamma(alpha=alpha, beta=distribution.alpha, lam=distribution.beta)
761798

762799

800+
@add_associated_likelihood(InverseGamma)
763801
@validate_prior_type
764802
def inverse_gamma_known_rate(
765803
*,
@@ -786,6 +824,7 @@ def inverse_gamma_known_rate(
786824
return Gamma(alpha=alpha_post, beta=beta_post)
787825

788826

827+
@add_associated_likelihood(Normal)
789828
@validate_prior_type
790829
def normal_known_variance(
791830
*,
@@ -856,6 +895,7 @@ def normal_known_variance(
856895
return Normal(mu=mu_post, sigma=var_post**0.5)
857896

858897

898+
@add_associated_likelihood(Normal)
859899
@validate_distribution_type
860900
def normal_known_variance_predictive(*, var: NUMERIC, distribution: Normal) -> Normal:
861901
"""Predictive distribution for a normal likelihood with known variance and a normal distribution on mean.
@@ -921,6 +961,7 @@ def normal_known_variance_predictive(*, var: NUMERIC, distribution: Normal) -> N
921961
return Normal(mu=distribution.mu, sigma=var_posterior_predictive**0.5)
922962

923963

964+
@add_associated_likelihood(Normal)
924965
@validate_prior_type
925966
def normal_known_precision(
926967
*,
@@ -987,6 +1028,7 @@ def normal_known_precision(
9871028
)
9881029

9891030

1031+
@add_associated_likelihood(Normal)
9901032
@validate_distribution_type
9911033
def normal_known_precision_predictive(
9921034
*,
@@ -1072,6 +1114,7 @@ def _normal_known_mean_inverse_gamma_prior(
10721114
return InverseGamma(alpha=alpha_post, beta=beta_post)
10731115

10741116

1117+
@add_associated_likelihood(Normal)
10751118
@validate_prior_type
10761119
def normal_known_mean(
10771120
*,
@@ -1111,6 +1154,7 @@ def normal_known_mean(
11111154
return posterior
11121155

11131156

1157+
@add_associated_likelihood(Normal)
11141158
@validate_distribution_type
11151159
def normal_known_mean_predictive(
11161160
*,
@@ -1209,6 +1253,7 @@ def _normal(
12091253
return mu_post, nu_post, alpha_post, beta_post
12101254

12111255

1256+
@add_associated_likelihood(Normal)
12121257
@validate_prior_type
12131258
def normal(
12141259
*,
@@ -1261,6 +1306,7 @@ def normal(
12611306
return prior.__class__(**kwargs)
12621307

12631308

1309+
@add_associated_likelihood(Normal)
12641310
@validate_prior_type
12651311
def normal_normal_inverse_gamma(
12661312
*,
@@ -1292,6 +1338,7 @@ def normal_normal_inverse_gamma(
12921338
return normal(x_total=x_total, x2_total=x2_total, n=n, prior=prior) # type: ignore
12931339

12941340

1341+
@add_associated_likelihood(Normal)
12951342
@validate_distribution_type
12961343
def normal_predictive(
12971344
*,
@@ -1319,6 +1366,7 @@ def normal_predictive(
13191366
)
13201367

13211368

1369+
@add_associated_likelihood(Normal)
13221370
@validate_distribution_type
13231371
def normal_normal_inverse_gamma_predictive(
13241372
*,
@@ -1426,6 +1474,7 @@ def linear_regression_predictive(
14261474
)
14271475

14281476

1477+
@add_associated_likelihood(Uniform)
14291478
@validate_prior_type
14301479
def uniform_pareto(
14311480
*,
@@ -1470,6 +1519,7 @@ def uniform_pareto(
14701519
return Pareto(x_m=x_m_post, alpha=alpha_post)
14711520

14721521

1522+
@add_associated_likelihood(Pareto)
14731523
@validate_prior_type
14741524
def pareto_gamma(
14751525
*,
@@ -1539,6 +1589,7 @@ def pareto_gamma(
15391589
return Gamma(alpha=alpha_post, beta=beta_post)
15401590

15411591

1592+
@add_associated_likelihood(Gamma)
15421593
@validate_prior_type
15431594
def gamma(
15441595
*,
@@ -1569,6 +1620,7 @@ def gamma(
15691620
return GammaProportional(p=p_post, q=q_post, r=r_post, s=s_post)
15701621

15711622

1623+
@add_associated_likelihood(Gamma)
15721624
@validate_prior_type
15731625
def gamma_known_rate(
15741626
*,
@@ -1597,6 +1649,7 @@ def gamma_known_rate(
15971649
return GammaKnownRateProportional(a=a_post, b=b_post, c=c_post)
15981650

15991651

1652+
@add_associated_likelihood(Beta)
16001653
@validate_prior_type
16011654
def beta(
16021655
*,
@@ -1626,6 +1679,7 @@ def beta(
16261679
return BetaProportional(p=p_post, q=q_post, k=k_post)
16271680

16281681

1682+
@add_associated_likelihood(VonMises)
16291683
@validate_prior_type
16301684
def von_mises_known_concentration(
16311685
*,
@@ -1661,6 +1715,7 @@ def von_mises_known_concentration(
16611715
return VonMisesKnownConcentration(a=a_post, b=b_post)
16621716

16631717

1718+
@add_associated_likelihood(VonMises)
16641719
@validate_prior_type
16651720
def von_mises_known_direction(
16661721
*,
@@ -1701,6 +1756,7 @@ def _multivariate_normal_known_precision(
17011756
return mu_post, precision_post
17021757

17031758

1759+
@add_associated_likelihood(MultivariateNormal)
17041760
@validate_prior_type
17051761
def multivariate_normal_known_covariance(
17061762
*,
@@ -1739,6 +1795,7 @@ def multivariate_normal_known_covariance(
17391795
return MultivariateNormal(mu=mu_post, cov=inv(precision_post))
17401796

17411797

1798+
@add_associated_likelihood(MultivariateNormal)
17421799
@validate_distribution_type
17431800
def multivariate_normal_known_covariance_predictive(
17441801
*,
@@ -1760,6 +1817,7 @@ def multivariate_normal_known_covariance_predictive(
17601817
return MultivariateNormal(mu=mu_pred, cov=cov_pred)
17611818

17621819

1820+
@add_associated_likelihood(MultivariateNormal)
17631821
@validate_prior_type
17641822
def multivariate_normal_known_precision(
17651823
*,
@@ -1796,6 +1854,7 @@ def multivariate_normal_known_precision(
17961854
return MultivariateNormal(mu=mu_post, cov=inv(precision_post))
17971855

17981856

1857+
@add_associated_likelihood(MultivariateNormal)
17991858
@validate_distribution_type
18001859
def multivariate_normal_known_precision_predictive(
18011860
*,
@@ -1819,6 +1878,7 @@ def multivariate_normal_known_precision_predictive(
18191878
return MultivariateNormal(mu=mu_pred, cov=cov_pred)
18201879

18211880

1881+
@add_associated_likelihood(MultivariateNormal)
18221882
@validate_prior_type
18231883
def multivariate_normal_known_mean(
18241884
*,
@@ -1846,6 +1906,7 @@ def multivariate_normal_known_mean(
18461906
)
18471907

18481908

1909+
@add_associated_likelihood(MultivariateNormal)
18491910
@validate_prior_type
18501911
def multivariate_normal(
18511912
*,
@@ -1934,6 +1995,7 @@ def multivariate_normal(
19341995
)
19351996

19361997

1998+
@add_associated_likelihood(MultivariateNormal)
19371999
@validate_distribution_type
19382000
def multivariate_normal_predictive(
19392001
*,
@@ -2045,6 +2107,7 @@ def multivariate_normal_predictive(
20452107
return MultivariateStudentT(mu=mu, sigma=sigma, nu=nu)
20462108

20472109

2110+
@add_associated_likelihood(LogNormal)
20482111
@validate_prior_type
20492112
def log_normal(
20502113
*,
@@ -2120,6 +2183,7 @@ def log_normal(
21202183
)
21212184

21222185

2186+
@add_associated_likelihood(Weibull)
21232187
@validate_prior_type
21242188
def weibull_inverse_gamma_known_shape(
21252189
*,

0 commit comments

Comments
 (0)