diff --git a/tests/distributions/test_multivariate.py b/tests/distributions/test_multivariate.py index 027d0b1915..88de323c12 100644 --- a/tests/distributions/test_multivariate.py +++ b/tests/distributions/test_multivariate.py @@ -13,20 +13,19 @@ # limitations under the License. import functools as ft -import re import warnings import numpy as np import numpy.random as npr import numpy.testing as npt import pytensor -import pytensor.tensor as pt import pytest import scipy.special as sp import scipy.stats as st from pytensor import tensor as pt from pytensor.tensor import TensorVariable +from pytensor.tensor.blockwise import Blockwise from pytensor.tensor.random.utils import broadcast_params from pytensor.tensor.slinalg import Cholesky @@ -2314,6 +2313,21 @@ def test_mvnormal_no_cholesky_in_model_logp(): assert not contains_cholesky_op(logp_dlogp._pytensor_function.maker.fgraph) +def test_mvnormal_blockwise_solve_opt(): + """Check that no blockwise show up in the d/logp graph of a 2D MvNormal with a single covariance. + + See #6993 + """ + with pm.Model() as m: + pm.MvNormal("y", mu=0, cov=pt.diag([2, 2]), shape=(3, 2)) + + logp = m.logp() + dlogp = pytensor.grad(logp, wrt=m.value_vars[0]) + fn = m.compile_fn(inputs=m.value_vars, outs=[logp, dlogp], point_fn=False) + + assert not any(isinstance(node.op, Blockwise) for node in fn.maker.fgraph.apply_nodes) + + def test_mvnormal_mu_convenience(): """Test that mu is broadcasted to the length of cov and provided a default of zero""" x = pm.MvNormal.dist(cov=np.eye(3))