Skip to content

Commit 3ca2226

Browse files
committed
Add test for Blockwise logp regression
1 parent 01ddcb8 commit 3ca2226

File tree

1 file changed

+16
-2
lines changed

1 file changed

+16
-2
lines changed

tests/distributions/test_multivariate.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,20 +13,19 @@
1313
# limitations under the License.
1414

1515
import functools as ft
16-
import re
1716
import warnings
1817

1918
import numpy as np
2019
import numpy.random as npr
2120
import numpy.testing as npt
2221
import pytensor
23-
import pytensor.tensor as pt
2422
import pytest
2523
import scipy.special as sp
2624
import scipy.stats as st
2725

2826
from pytensor import tensor as pt
2927
from pytensor.tensor import TensorVariable
28+
from pytensor.tensor.blockwise import Blockwise
3029
from pytensor.tensor.random.utils import broadcast_params
3130
from pytensor.tensor.slinalg import Cholesky
3231

@@ -2314,6 +2313,21 @@ def test_mvnormal_no_cholesky_in_model_logp():
23142313
assert not contains_cholesky_op(logp_dlogp._pytensor_function.maker.fgraph)
23152314

23162315

2316+
def test_mvnormal_blockwise_solve_opt():
2317+
"""Check that no blockwise show up in the d/logp graph of a 2D MvNormal with a single covariance.
2318+
2319+
See #6993
2320+
"""
2321+
with pm.Model() as m:
2322+
pm.MvNormal("y", mu=0, cov=pt.diag([2, 2]), shape=(3, 2))
2323+
2324+
logp = m.logp()
2325+
dlogp = pytensor.grad(logp, wrt=m.value_vars[0])
2326+
fn = m.compile_fn(inputs=m.value_vars, outs=[logp, dlogp], point_fn=False)
2327+
2328+
assert not any(isinstance(node.op, Blockwise) for node in fn.maker.fgraph.apply_nodes)
2329+
2330+
23172331
def test_mvnormal_mu_convenience():
23182332
"""Test that mu is broadcasted to the length of cov and provided a default of zero"""
23192333
x = pm.MvNormal.dist(cov=np.eye(3))

0 commit comments

Comments
 (0)