Skip to content

Commit 68ab171

Browse files
committed
type various distributions
1 parent 1d4bc50 commit 68ab171

1 file changed

Lines changed: 73 additions & 59 deletions

File tree

conjugate/distributions.py

Lines changed: 73 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@
4747

4848
from conjugate._compound_gamma import compound_gamma
4949
from conjugate._beta_geometric import beta_geometric
50-
from conjugate._typing import NUMERIC
50+
from conjugate._typing import NUMERIC, Real, PositiveReal, Natural, Probability
5151
from conjugate.plot import (
5252
DirichletPlotDistMixin,
5353
DiscretePlotMixin,
@@ -72,14 +72,14 @@ class Beta(ContinuousPlotDistMixin, SliceMixin):
7272
7373
"""
7474

75-
alpha: NUMERIC
76-
beta: NUMERIC
75+
alpha: PositiveReal
76+
beta: PositiveReal
7777

7878
def __post_init__(self) -> None:
7979
self.max_value = 1.0
8080

8181
@classmethod
82-
def from_mean(cls, mean: NUMERIC, alpha: NUMERIC) -> "Beta":
82+
def from_mean(cls, mean: Probability, alpha: PositiveReal) -> "Beta":
8383
"""Alternative constructor from mean and alpha."""
8484
beta = get_beta_param_from_mean_and_alpha(mean=mean, alpha=alpha)
8585
return cls(alpha=alpha, beta=beta)
@@ -90,7 +90,9 @@ def uninformative(cls) -> "Beta":
9090

9191
@classmethod
9292
def from_successes_and_failures(
93-
cls, successes: NUMERIC, failures: NUMERIC
93+
cls,
94+
successes: PositiveReal,
95+
failures: PositiveReal,
9496
) -> "Beta":
9597
"""Alternative constructor based on hyperparameter interpretation."""
9698
alpha = successes + 1
@@ -112,8 +114,8 @@ class Binomial(DiscretePlotMixin, SliceMixin):
112114
113115
"""
114116

115-
n: NUMERIC
116-
p: NUMERIC
117+
n: Natural
118+
p: Probability
117119

118120
def __post_init__(self):
119121
if isinstance(self.n, np.ndarray):
@@ -176,8 +178,8 @@ class Multinomial(SliceMixin):
176178
177179
"""
178180

179-
n: NUMERIC
180-
p: NUMERIC
181+
n: Natural
182+
p: Probability
181183

182184
@property
183185
def dist(self):
@@ -211,7 +213,7 @@ class Exponential(ContinuousPlotDistMixin, SliceMixin):
211213
212214
"""
213215

214-
lam: NUMERIC
216+
lam: PositiveReal
215217

216218
@property
217219
def dist(self):
@@ -235,11 +237,15 @@ class Gamma(ContinuousPlotDistMixin, SliceMixin):
235237
beta: rate parameter
236238
"""
237239

238-
alpha: NUMERIC
239-
beta: NUMERIC
240+
alpha: PositiveReal
241+
beta: PositiveReal
240242

241243
@classmethod
242-
def from_occurrences_in_intervals(cls, occurrences: NUMERIC, intervals: NUMERIC):
244+
def from_occurrences_in_intervals(
245+
cls,
246+
occurrences: PositiveReal,
247+
intervals: PositiveReal,
248+
) -> "Gamma":
243249
return cls(alpha=occurrences, beta=intervals)
244250

245251
@property
@@ -262,8 +268,8 @@ class NegativeBinomial(DiscretePlotMixin, SliceMixin):
262268
263269
"""
264270

265-
n: NUMERIC
266-
p: NUMERIC
271+
n: Natural
272+
p: Probability
267273

268274
@property
269275
def dist(self):
@@ -286,9 +292,9 @@ class Hypergeometric(DiscretePlotMixin, SliceMixin):
286292
287293
"""
288294

289-
N: NUMERIC
290-
k: NUMERIC
291-
n: NUMERIC
295+
N: Natural
296+
k: Natural
297+
n: Natural
292298

293299
def __post_init__(self) -> None:
294300
if isinstance(self.N, np.ndarray):
@@ -310,7 +316,7 @@ class Poisson(DiscretePlotMixin, SliceMixin):
310316
311317
"""
312318

313-
lam: NUMERIC
319+
lam: PositiveReal
314320

315321
@property
316322
def dist(self):
@@ -338,9 +344,9 @@ class BetaBinomial(DiscretePlotMixin, SliceMixin):
338344
339345
"""
340346

341-
n: NUMERIC
342-
alpha: NUMERIC
343-
beta: NUMERIC
347+
n: Natural
348+
alpha: PositiveReal
349+
beta: PositiveReal
344350

345351
def __post_init__(self):
346352
if isinstance(self.n, np.ndarray):
@@ -364,9 +370,9 @@ class BetaNegativeBinomial(DiscretePlotMixin, SliceMixin):
364370
365371
"""
366372

367-
n: NUMERIC
368-
alpha: NUMERIC
369-
beta: NUMERIC
373+
n: Natural
374+
alpha: PositiveReal
375+
beta: PositiveReal
370376

371377
def __post_init__(self):
372378
if isinstance(self.n, np.ndarray):
@@ -393,7 +399,7 @@ class Geometric(DiscretePlotMixin, SliceMixin):
393399
394400
"""
395401

396-
p: NUMERIC
402+
p: Probability
397403
one_start: bool = True
398404

399405
@property
@@ -413,8 +419,8 @@ class BetaGeometric(DiscretePlotMixin, SliceMixin):
413419
414420
"""
415421

416-
alpha: NUMERIC
417-
beta: NUMERIC
422+
alpha: PositiveReal
423+
beta: PositiveReal
418424
one_start: bool = True
419425

420426
@property
@@ -432,25 +438,33 @@ class Normal(ContinuousPlotDistMixin, SliceMixin):
432438
433439
"""
434440

435-
mu: NUMERIC
436-
sigma: NUMERIC
441+
mu: Real
442+
sigma: PositiveReal
437443

438444
@property
439445
def dist(self):
440446
return stats.norm(self.mu, self.sigma)
441447

442448
@classmethod
443-
def uninformative(cls, sigma: NUMERIC = 1) -> "Normal":
449+
def uninformative(cls, sigma: PositiveReal = 1.0) -> "Normal":
444450
"""Uninformative normal distribution."""
445451
return cls(mu=0, sigma=sigma)
446452

447453
@classmethod
448-
def from_mean_and_variance(cls, mean: NUMERIC, variance: NUMERIC) -> "Normal":
454+
def from_mean_and_variance(
455+
cls,
456+
mean: Real,
457+
variance: PositiveReal,
458+
) -> "Normal":
449459
"""Alternative constructor from mean and variance."""
450460
return cls(mu=mean, sigma=variance**0.5)
451461

452462
@classmethod
453-
def from_mean_and_precision(cls, mean: NUMERIC, precision: NUMERIC) -> "Normal":
463+
def from_mean_and_precision(
464+
cls,
465+
mean: Real,
466+
precision: PositiveReal,
467+
) -> "Normal":
454468
"""Alternative constructor from mean and precision."""
455469
return cls(mu=mean, sigma=precision**-0.5)
456470

@@ -495,8 +509,8 @@ class Uniform(ContinuousPlotDistMixin, SliceMixin):
495509
496510
"""
497511

498-
low: NUMERIC
499-
high: NUMERIC
512+
low: Real
513+
high: Real
500514

501515
def __post_init__(self):
502516
self.min_value = self.low
@@ -517,8 +531,8 @@ class Pareto(ContinuousPlotDistMixin, SliceMixin):
517531
518532
"""
519533

520-
x_m: NUMERIC
521-
alpha: NUMERIC
534+
x_m: Real
535+
alpha: Real
522536

523537
@property
524538
def dist(self):
@@ -535,8 +549,8 @@ class InverseGamma(ContinuousPlotDistMixin, SliceMixin):
535549
536550
"""
537551

538-
alpha: NUMERIC
539-
beta: NUMERIC
552+
alpha: Real
553+
beta: Real
540554

541555
@property
542556
def dist(self):
@@ -558,11 +572,11 @@ class NormalInverseGamma:
558572
559573
"""
560574

561-
mu: NUMERIC
562-
alpha: NUMERIC
563-
beta: NUMERIC
575+
mu: Real
576+
alpha: PositiveReal
577+
beta: PositiveReal
564578
delta_inverse: NUMERIC | None = None
565-
nu: NUMERIC | None = None
579+
nu: PositiveReal | None = None
566580

567581
def __post_init__(self) -> None:
568582
if self.delta_inverse is None and self.nu is None:
@@ -678,9 +692,9 @@ class StudentT(ContinuousPlotDistMixin, SliceMixin):
678692
679693
"""
680694

681-
mu: NUMERIC
682-
sigma: NUMERIC
683-
nu: NUMERIC
695+
mu: Real
696+
sigma: PositiveReal
697+
nu: PositiveReal
684698

685699
@property
686700
def dist(self):
@@ -725,8 +739,8 @@ class Lomax(ContinuousPlotDistMixin, SliceMixin):
725739
726740
"""
727741

728-
alpha: NUMERIC
729-
lam: NUMERIC
742+
alpha: PositiveReal
743+
lam: PositiveReal
730744

731745
@property
732746
def dist(self):
@@ -744,9 +758,9 @@ class CompoundGamma(ContinuousPlotDistMixin, SliceMixin):
744758
745759
"""
746760

747-
alpha: NUMERIC
748-
beta: NUMERIC
749-
lam: NUMERIC
761+
alpha: PositiveReal
762+
beta: PositiveReal
763+
lam: PositiveReal
750764

751765
@property
752766
def dist(self):
@@ -879,8 +893,8 @@ class VonMises(ContinuousPlotDistMixin, SliceMixin):
879893
880894
"""
881895

882-
mu: NUMERIC
883-
kappa: NUMERIC
896+
mu: Real
897+
kappa: PositiveReal
884898

885899
def __post_init__(self) -> None:
886900
self.min_value = -np.pi
@@ -961,8 +975,8 @@ class ScaledInverseChiSquared(ContinuousPlotDistMixin, SliceMixin):
961975
962976
"""
963977

964-
nu: NUMERIC
965-
sigma2: NUMERIC
978+
nu: PositiveReal
979+
sigma2: PositiveReal
966980

967981
@classmethod
968982
def from_inverse_gamma(
@@ -1299,8 +1313,8 @@ class LogNormal(ContinuousPlotDistMixin, SliceMixin):
12991313
13001314
"""
13011315

1302-
mu: NUMERIC
1303-
sigma: NUMERIC
1316+
mu: Real
1317+
sigma: PositiveReal
13041318

13051319
@property
13061320
def dist(self):
@@ -1346,8 +1360,8 @@ class Weibull(ContinuousPlotDistMixin, SliceMixin):
13461360
13471361
"""
13481362

1349-
beta: NUMERIC
1350-
theta: NUMERIC
1363+
beta: PositiveReal
1364+
theta: PositiveReal
13511365

13521366
@property
13531367
def dist(self):

0 commit comments

Comments
 (0)