Skip to content

Commit 118be0f

Browse files
committed
Add test for Blockwise logp regression
1 parent 986738f commit 118be0f

File tree

1 file changed

+16
-1
lines changed

1 file changed

+16
-1
lines changed

tests/distributions/test_multivariate.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,13 @@
1919
import numpy.random as npr
2020
import numpy.testing as npt
2121
import pytensor
22-
import pytensor.tensor as pt
2322
import pytest
2423
import scipy.special as sp
2524
import scipy.stats as st
2625

2726
from pytensor import tensor as pt
2827
from pytensor.tensor import TensorVariable
28+
from pytensor.tensor.blockwise import Blockwise
2929
from pytensor.tensor.random.utils import broadcast_params
3030
from pytensor.tensor.slinalg import Cholesky
3131

@@ -2387,6 +2387,21 @@ def test_mvnormal_no_cholesky_in_model_logp():
23872387
assert not contains_cholesky_op(logp_dlogp._pytensor_function.maker.fgraph)
23882388

23892389

2390+
def test_mvnormal_blockwise_solve_opt():
2391+
"""Check that no blockwise show up in the d/logp graph of a 2D MvNormal with a single covariance.
2392+
2393+
See #6993
2394+
"""
2395+
with pm.Model() as m:
2396+
pm.MvNormal("y", mu=0, cov=pt.diag([2, 2]), shape=(3, 2))
2397+
2398+
logp = m.logp()
2399+
dlogp = pytensor.grad(logp, wrt=m.value_vars[0])
2400+
fn = m.compile_fn(inputs=m.value_vars, outs=[logp, dlogp], point_fn=False)
2401+
2402+
assert not any(isinstance(node.op, Blockwise) for node in fn.maker.fgraph.apply_nodes)
2403+
2404+
23902405
def test_mvnormal_mu_convenience():
23912406
"""Test that mu is broadcasted to the length of cov and provided a default of zero"""
23922407
x = pm.MvNormal.dist(cov=np.eye(3))

0 commit comments

Comments
 (0)