Skip to content

Commit 458f28e

Browse files
danhphanricardoV94
authored andcommitted
remove MultinomialRV override
1 parent bae5087 commit 458f28e

File tree

1 file changed

+1
-25
lines changed

1 file changed

+1
-25
lines changed

pymc/distributions/multivariate.py

Lines changed: 1 addition & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
from aesara.sparse.basic import sp_sum
3131
from aesara.tensor import gammaln, sigmoid
3232
from aesara.tensor.nlinalg import det, eigh, matrix_inverse, trace
33-
from aesara.tensor.random.basic import MultinomialRV, dirichlet, multivariate_normal
33+
from aesara.tensor.random.basic import dirichlet, multinomial, multivariate_normal
3434
from aesara.tensor.random.op import RandomVariable, default_supp_shape_from_params
3535
from aesara.tensor.random.utils import broadcast_params, normalize_size_param
3636
from aesara.tensor.slinalg import Cholesky
@@ -490,30 +490,6 @@ def logp(value, a):
490490
)
491491

492492

493-
class MultinomialRV(MultinomialRV):
494-
"""Aesara's `MultinomialRV` doesn't broadcast; this one does."""
495-
496-
@classmethod
497-
def rng_fn(cls, rng, n, p, size):
498-
if n.ndim > 0 or p.ndim > 1:
499-
n, p = broadcast_params([n, p], cls.ndims_params)
500-
size = tuple(size or ())
501-
502-
if size:
503-
n = np.broadcast_to(n, size)
504-
p = np.broadcast_to(p, size + (p.shape[-1],))
505-
506-
res = np.empty(p.shape)
507-
for idx in np.ndindex(p.shape[:-1]):
508-
res[idx] = rng.multinomial(n[idx], p[idx])
509-
return res
510-
else:
511-
return rng.multinomial(n, p, size=size)
512-
513-
514-
multinomial = MultinomialRV()
515-
516-
517493
class Multinomial(Discrete):
518494
r"""
519495
Multinomial log-likelihood.

0 commit comments

Comments
 (0)