Skip to content

Commit e709c2c

Browse files
Add Multinomial moments (#5201)
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent 8d09a1c commit e709c2c

File tree

3 files changed

+54
-27
lines changed

3 files changed

+54
-27
lines changed

pymc/distributions/multivariate.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -525,6 +525,21 @@ def dist(cls, n, p, *args, **kwargs):
525525

526526
return super().dist([n, p], *args, **kwargs)
527527

528+
def get_moment(rv, size, n, p):
529+
if p.ndim > 1:
530+
n = at.shape_padright(n)
531+
if (p.ndim == 1) & (n.ndim > 0):
532+
n = at.shape_padright(n)
533+
p = at.shape_padleft(p)
534+
mode = at.round(n * p)
535+
diff = n - at.sum(mode, axis=-1, keepdims=True)
536+
inc_bool_arr = at.abs_(diff) > 0
537+
mode = at.inc_subtensor(mode[inc_bool_arr.nonzero()], diff[inc_bool_arr.nonzero()])
538+
if not rv_size_is_none(size):
539+
output_size = at.concatenate([size, p.shape])
540+
mode = at.full(output_size, mode)
541+
return mode
542+
528543
def logp(value, n, p):
529544
"""
530545
Calculate log-probability of Multinomial distribution

pymc/tests/test_distributions.py

Lines changed: 0 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -2160,25 +2160,6 @@ def test_multinomial(self, n):
21602160
Multinomial, Vector(Nat, n), {"p": Simplex(n), "n": Nat}, multinomial_logpdf
21612161
)
21622162

2163-
@pytest.mark.skip(reason="Moment calculations have not been refactored yet")
2164-
@pytest.mark.parametrize(
2165-
"p,n",
2166-
[
2167-
[[0.25, 0.25, 0.25, 0.25], 1],
2168-
[[0.3, 0.6, 0.05, 0.05], 2],
2169-
[[0.3, 0.6, 0.05, 0.05], 10],
2170-
],
2171-
)
2172-
def test_multinomial_mode(self, p, n):
2173-
_p = np.array(p)
2174-
with Model() as model:
2175-
m = Multinomial("m", n, _p, _p.shape)
2176-
assert_allclose(m.distribution.mode.eval().sum(), n)
2177-
_p = np.array([p, p])
2178-
with Model() as model:
2179-
m = Multinomial("m", n, _p, _p.shape)
2180-
assert_allclose(m.distribution.mode.eval().sum(axis=-1), n)
2181-
21822163
@pytest.mark.parametrize(
21832164
"p, size, n",
21842165
[
@@ -2206,14 +2187,6 @@ def test_multinomial_random(self, p, size, n):
22062187

22072188
assert m.eval().shape == size + p.shape
22082189

2209-
@pytest.mark.skip(reason="Moment calculations have not been refactored yet")
2210-
def test_multinomial_mode_with_shape(self):
2211-
n = [1, 10]
2212-
p = np.asarray([[0.25, 0.25, 0.25, 0.25], [0.26, 0.26, 0.26, 0.22]])
2213-
with Model() as model:
2214-
m = Multinomial("m", n=n, p=p, size=(2, 4))
2215-
assert_allclose(m.distribution.mode.eval().sum(axis=-1), n)
2216-
22172190
def test_multinomial_vec(self):
22182191
vals = np.array([[2, 4, 4], [3, 3, 4]])
22192192
p = np.array([0.2, 0.3, 0.5])

pymc/tests/test_distributions_moments.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
LogNormal,
4242
MatrixNormal,
4343
Moyal,
44+
Multinomial,
4445
MvStudentT,
4546
NegativeBinomial,
4647
Normal,
@@ -1104,6 +1105,44 @@ def test_polyagamma_moment(h, z, size, expected):
11041105
assert_moment_is_expected(model, expected)
11051106

11061107

1108+
@pytest.mark.parametrize(
1109+
"p, n, size, expected",
1110+
[
1111+
(np.array([0.25, 0.25, 0.25, 0.25]), 1, None, np.array([1, 0, 0, 0])),
1112+
(np.array([0.3, 0.6, 0.05, 0.05]), 2, None, np.array([1, 1, 0, 0])),
1113+
(np.array([0.3, 0.6, 0.05, 0.05]), 10, None, np.array([4, 6, 0, 0])),
1114+
(
1115+
np.array([[0.3, 0.6, 0.05, 0.05], [0.25, 0.25, 0.25, 0.25]]),
1116+
10,
1117+
None,
1118+
np.array([[4, 6, 0, 0], [4, 2, 2, 2]]),
1119+
),
1120+
(
1121+
np.array([[0.25, 0.25, 0.25, 0.25], [0.26, 0.26, 0.26, 0.22]]),
1122+
np.array([1, 10]),
1123+
None,
1124+
np.array([[1, 0, 0, 0], [2, 3, 3, 2]]),
1125+
),
1126+
(
1127+
np.array([0.26, 0.26, 0.26, 0.22]),
1128+
np.array([1, 10]),
1129+
None,
1130+
np.array([[1, 0, 0, 0], [2, 3, 3, 2]]),
1131+
),
1132+
(
1133+
np.array([[0.25, 0.25, 0.25, 0.25], [0.26, 0.26, 0.26, 0.22]]),
1134+
np.array([1, 10]),
1135+
2,
1136+
np.full((2, 2, 4), [[1, 0, 0, 0], [2, 3, 3, 2]]),
1137+
),
1138+
],
1139+
)
1140+
def test_multinomial_moment(p, n, size, expected):
1141+
with Model() as model:
1142+
Multinomial("x", n=n, p=p, size=size)
1143+
assert_moment_is_expected(model, expected)
1144+
1145+
11071146
@pytest.mark.parametrize(
11081147
"psi, mu, alpha, size, expected",
11091148
[

0 commit comments

Comments
 (0)