|
9 | 9 | from pytensor.gradient import grad
|
10 | 10 | from pytensor.graph import Apply, Op
|
11 | 11 | from pytensor.graph.replace import vectorize_node
|
12 |
| -from pytensor.tensor import tensor |
| 12 | +from pytensor.tensor import diagonal, log, tensor |
13 | 13 | from pytensor.tensor.blockwise import Blockwise, _parse_gufunc_signature
|
14 | 14 | from pytensor.tensor.nlinalg import MatrixInverse
|
15 |
| -from pytensor.tensor.slinalg import Cholesky, Solve |
| 15 | +from pytensor.tensor.slinalg import Cholesky, Solve, cholesky, solve_triangular |
16 | 16 |
|
17 | 17 |
|
18 | 18 | def test_vectorize_blockwise():
|
@@ -320,3 +320,41 @@ class TestSolveVector(BlockwiseOpTester):
|
320 | 320 | class TestSolveMatrix(BlockwiseOpTester):
|
321 | 321 | core_op = Solve(lower=True, b_ndim=2)
|
322 | 322 | signature = "(m, m),(m, n) -> (m, n)"
|
| 323 | + |
| 324 | + |
| 325 | +@pytest.mark.parametrize( |
| 326 | + "mu_batch_shape", [(), (1000,), (4, 1000)], ids=lambda arg: f"mu:{arg}" |
| 327 | +) |
| 328 | +@pytest.mark.parametrize( |
| 329 | + "cov_batch_shape", [(), (1000,), (4, 1000)], ids=lambda arg: f"cov:{arg}" |
| 330 | +) |
| 331 | +def test_batched_mvnormal_logp_and_dlogp(mu_batch_shape, cov_batch_shape, benchmark): |
| 332 | + rng = np.random.default_rng(sum(map(ord, "batched_mvnormal"))) |
| 333 | + |
| 334 | + value_batch_shape = mu_batch_shape |
| 335 | + if len(cov_batch_shape) > len(mu_batch_shape): |
| 336 | + value_batch_shape = cov_batch_shape |
| 337 | + |
| 338 | + value = tensor("value", shape=(*value_batch_shape, 10)) |
| 339 | + mu = tensor("mu", shape=(*mu_batch_shape, 10)) |
| 340 | + cov = tensor("cov", shape=(*cov_batch_shape, 10, 10)) |
| 341 | + |
| 342 | + test_values = [ |
| 343 | + rng.normal(size=value.type.shape), |
| 344 | + rng.normal(size=mu.type.shape), |
| 345 | + np.eye(cov.type.shape[-1]) * np.abs(rng.normal(size=cov.type.shape)), |
| 346 | + ] |
| 347 | + |
| 348 | + chol_cov = cholesky(cov, lower=True, on_error="raise") |
| 349 | + delta_trans = solve_triangular(chol_cov, value - mu, b_ndim=1) |
| 350 | + quaddist = (delta_trans**2).sum(axis=-1) |
| 351 | + diag = diagonal(chol_cov, axis1=-2, axis2=-1) |
| 352 | + logdet = log(diag).sum(axis=-1) |
| 353 | + k = value.shape[-1] |
| 354 | + norm = -0.5 * k * (np.log(2 * np.pi)) |
| 355 | + |
| 356 | + logp = norm - 0.5 * quaddist - logdet |
| 357 | + dlogp = grad(logp.sum(), wrt=[value, mu, cov]) |
| 358 | + |
| 359 | + fn = pytensor.function([value, mu, cov], [logp, *dlogp]) |
| 360 | + benchmark(fn, *test_values) |
0 commit comments