Skip to content

Commit 275c145

Browse files
authored
Add ExGaussian moment (#5165)
1 parent 5dacdae commit 275c145

File tree

2 files changed

+24
-4
lines changed

2 files changed

+24
-4
lines changed

pymc/distributions/continuous.py

+7-4
Original file line numberDiff line numberDiff line change
@@ -2779,15 +2779,18 @@ def dist(cls, mu=0.0, sigma=None, nu=None, sd=None, *args, **kwargs):
27792779
sigma = at.as_tensor_variable(floatX(sigma))
27802780
nu = at.as_tensor_variable(floatX(nu))
27812781

2782-
# sd = sigma
2783-
# mean = mu + nu
2784-
# variance = (sigma ** 2) + (nu ** 2)
2785-
27862782
assert_negative_support(sigma, "sigma", "ExGaussian")
27872783
assert_negative_support(nu, "nu", "ExGaussian")
27882784

27892785
return super().dist([mu, sigma, nu], *args, **kwargs)
27902786

2787+
def get_moment(rv, size, mu, sigma, nu):
2788+
mu, nu, _ = at.broadcast_arrays(mu, nu, sigma)
2789+
moment = mu + nu
2790+
if not rv_size_is_none(size):
2791+
moment = at.full(size, moment)
2792+
return moment
2793+
27912794
def logp(value, mu, sigma, nu):
27922795
"""
27932796
Calculate log-probability of ExGaussian distribution at specified value.

pymc/tests/test_distributions_moments.py

+17
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
ChiSquared,
1212
Constant,
1313
DiscreteUniform,
14+
ExGaussian,
1415
Exponential,
1516
Flat,
1617
Gamma,
@@ -541,6 +542,22 @@ def test_logistic_moment(mu, s, size, expected):
541542
assert_moment_is_expected(model, expected)
542543

543544

545+
@pytest.mark.parametrize(
546+
"mu, nu, sigma, size, expected",
547+
[
548+
(1, 1, None, None, 2),
549+
(1, 1, np.ones((2, 5)), None, np.full([2, 5], 2)),
550+
(1, 1, None, 5, np.full(5, 2)),
551+
(1, np.arange(1, 6), None, None, np.arange(2, 7)),
552+
(1, np.arange(1, 6), None, (2, 5), np.full((2, 5), np.arange(2, 7))),
553+
],
554+
)
555+
def test_exgaussian_moment(mu, nu, sigma, size, expected):
556+
with Model() as model:
557+
ExGaussian("x", mu=mu, sigma=sigma, nu=nu, size=size)
558+
assert_moment_is_expected(model, expected)
559+
560+
544561
@pytest.mark.parametrize(
545562
"p, size, expected",
546563
[

0 commit comments

Comments
 (0)