Skip to content

Commit 4f4379c

Browse files
committed
Add mvnormal logp dlogp benchmark test
1 parent 3a92af3 commit 4f4379c

File tree

1 file changed

+40
-2
lines changed

1 file changed

+40
-2
lines changed

tests/tensor/test_blockwise.py

Lines changed: 40 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,10 @@
99
from pytensor.gradient import grad
1010
from pytensor.graph import Apply, Op
1111
from pytensor.graph.replace import vectorize_node
12-
from pytensor.tensor import tensor
12+
from pytensor.tensor import diagonal, log, tensor
1313
from pytensor.tensor.blockwise import Blockwise, _parse_gufunc_signature
1414
from pytensor.tensor.nlinalg import MatrixInverse
15-
from pytensor.tensor.slinalg import Cholesky, Solve
15+
from pytensor.tensor.slinalg import Cholesky, Solve, cholesky, solve_triangular
1616

1717

1818
def test_vectorize_blockwise():
@@ -320,3 +320,41 @@ class TestSolveVector(BlockwiseOpTester):
320320
class TestSolveMatrix(BlockwiseOpTester):
321321
core_op = Solve(lower=True, b_ndim=2)
322322
signature = "(m, m),(m, n) -> (m, n)"
323+
324+
325+
@pytest.mark.parametrize(
326+
"mu_batch_shape", [(), (1000,), (4, 1000)], ids=lambda arg: f"mu:{arg}"
327+
)
328+
@pytest.mark.parametrize(
329+
"cov_batch_shape", [(), (1000,), (4, 1000)], ids=lambda arg: f"cov:{arg}"
330+
)
331+
def test_batched_mvnormal_logp_and_dlogp(mu_batch_shape, cov_batch_shape, benchmark):
332+
rng = np.random.default_rng(sum(map(ord, "batched_mvnormal")))
333+
334+
value_batch_shape = mu_batch_shape
335+
if len(cov_batch_shape) > len(mu_batch_shape):
336+
value_batch_shape = cov_batch_shape
337+
338+
value = tensor("value", shape=(*value_batch_shape, 10))
339+
mu = tensor("mu", shape=(*mu_batch_shape, 10))
340+
cov = tensor("cov", shape=(*cov_batch_shape, 10, 10))
341+
342+
test_values = [
343+
rng.normal(size=value.type.shape),
344+
rng.normal(size=mu.type.shape),
345+
np.eye(cov.type.shape[-1]) * np.abs(rng.normal(size=cov.type.shape)),
346+
]
347+
348+
chol_cov = cholesky(cov, lower=True, on_error="raise")
349+
delta_trans = solve_triangular(chol_cov, value - mu, b_ndim=1)
350+
quaddist = (delta_trans**2).sum(axis=-1)
351+
diag = diagonal(chol_cov, axis1=-2, axis2=-1)
352+
logdet = log(diag).sum(axis=-1)
353+
k = value.shape[-1]
354+
norm = -0.5 * k * (np.log(2 * np.pi))
355+
356+
logp = norm - 0.5 * quaddist - logdet
357+
dlogp = grad(logp.sum(), wrt=[value, mu, cov])
358+
359+
fn = pytensor.function([value, mu, cov], [logp, *dlogp])
360+
benchmark(fn, *test_values)

0 commit comments

Comments
 (0)