diff --git a/pytensor/tensor/rewriting/linalg.py b/pytensor/tensor/rewriting/linalg.py
index e681eb6a17..475e454037 100644
--- a/pytensor/tensor/rewriting/linalg.py
+++ b/pytensor/tensor/rewriting/linalg.py
@@ -1,7 +1,7 @@
 import logging
 from typing import cast
 
-from pytensor.graph.rewriting.basic import node_rewriter
+from pytensor.graph.rewriting.basic import copy_stack_trace, node_rewriter
 from pytensor.tensor.basic import TensorVariable, diagonal, swapaxes
 from pytensor.tensor.blas import Dot22
 from pytensor.tensor.blockwise import Blockwise
@@ -13,7 +13,14 @@
     register_specialize,
     register_stabilize,
 )
-from pytensor.tensor.slinalg import Cholesky, Solve, cholesky, solve, solve_triangular
+from pytensor.tensor.slinalg import (
+    Cholesky,
+    Solve,
+    SolveBase,
+    cholesky,
+    solve,
+    solve_triangular,
+)
 
 
 logger = logging.getLogger(__name__)
@@ -131,6 +138,52 @@ def generic_solve_to_solve_triangular(fgraph, node):
     ]
 
 
+@register_stabilize
+@register_specialize
+@node_rewriter([Blockwise])
+def batched_vector_b_solve_to_matrix_b_solve(fgraph, node):
+    """Replace a batched Solve(a, b, b_ndim=1) by Solve(a, b.T, b_ndim=2).T
+
+    `a` must have no batched dimensions, while `b` can have arbitrary batched dimensions.
+    Only the last two dimensions of `b` and the output are swapped.
+    """
+    core_op = node.op.core_op
+
+    if not isinstance(core_op, SolveBase):
+        return None
+
+    if node.op.core_op.b_ndim != 1:
+        return None
+
+    [a, b] = node.inputs
+
+    # Check `b` is actually batched
+    if b.type.ndim == 1:
+        return None
+
+    # Check `a` is a matrix (possibly with degenerate dims on the left)
+    a_bcast_batch_dims = a.type.broadcastable[:-2]
+    if not all(a_bcast_batch_dims):
+        return None
+    # We squeeze degenerate dims; any that are still needed will be reintroduced by the new_solve
+    elif len(a_bcast_batch_dims):
+        a = a.squeeze(axis=tuple(range(len(a_bcast_batch_dims))))
+
+    # Recreate solve Op with b_ndim=2
+    props = core_op._props_dict()
+    props["b_ndim"] = 2
+    new_core_op = type(core_op)(**props)
+    matrix_b_solve = Blockwise(new_core_op)
+
+    # Apply the rewrite
+    new_solve = _T(matrix_b_solve(a, _T(b)))
+
+    old_solve = node.outputs[0]
+    copy_stack_trace(old_solve, new_solve)
+
+    return [new_solve]
+
+
 @register_canonicalize
 @register_stabilize
 @register_specialize
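For orientation, here is a minimal sketch of the transformation this rewrite performs once the patch is applied. The snippet is illustrative and not part of the patch; shapes and variable names are made up. A solve of one matrix `a` against a batch of vector right-hand sides compiles to a single matrix-`b` solve wrapped in axis swaps, rather than a Blockwise loop of vector solves:

    import pytensor
    import pytensor.tensor as pt
    from pytensor.tensor.slinalg import solve

    a = pt.tensor("a", shape=(3, 3))
    b = pt.tensor("b", shape=(None, 3))  # batch of length-3 vectors

    # Graph as built: Blockwise(Solve(b_ndim=1)) looping over the batch dim of b
    out = solve(a, b, b_ndim=1)

    # The stabilize/specialize passes rewrite this into the equivalent
    # swapaxes(solve(a, swapaxes(b, -1, -2), b_ndim=2), -1, -2)
    fn = pytensor.function([a, b], out)
    pytensor.dprint(fn)  # inspect the rewritten graph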
diff --git a/tests/tensor/rewriting/test_linalg.py b/tests/tensor/rewriting/test_linalg.py
index 498ed18bf9..e4c636f87c 100644
--- a/tests/tensor/rewriting/test_linalg.py
+++ b/tests/tensor/rewriting/test_linalg.py
@@ -1,7 +1,6 @@
 from functools import partial
 
 import numpy as np
-import numpy.linalg
 import pytest
 import scipy.linalg
 from numpy.testing import assert_allclose
@@ -17,7 +16,16 @@
 from pytensor.tensor.math import _allclose, dot, matmul
 from pytensor.tensor.nlinalg import Det, MatrixInverse, matrix_inverse
 from pytensor.tensor.rewriting.linalg import inv_as_solve
-from pytensor.tensor.slinalg import Cholesky, Solve, SolveTriangular, cholesky, solve
+from pytensor.tensor.slinalg import (
+    Cholesky,
+    Solve,
+    SolveBase,
+    SolveTriangular,
+    cho_solve,
+    cholesky,
+    solve,
+    solve_triangular,
+)
 from pytensor.tensor.type import dmatrix, matrix, tensor, vector
 from tests import unittest_tools as utt
 from tests.test_rop import break_op
@@ -231,3 +239,70 @@ def test_local_det_chol():
     f = function([X], [L, det_X, X])
     nodes = f.maker.fgraph.toposort()
     assert not any(isinstance(node, Det) for node in nodes)
+
+
+class TestBatchedVectorBSolveToMatrixBSolve:
+    rewrite_name = "batched_vector_b_solve_to_matrix_b_solve"
+
+    @staticmethod
+    def any_vector_b_solve(fn):
+        return any(
+            (
+                isinstance(node.op, Blockwise)
+                and isinstance(node.op.core_op, SolveBase)
+                and node.op.core_op.b_ndim == 1
+            )
+            for node in fn.maker.fgraph.apply_nodes
+        )
+
+    @pytest.mark.parametrize("solve_op", (solve, solve_triangular, cho_solve))
+    def test_valid_cases(self, solve_op):
+        rng = np.random.default_rng(sum(map(ord, solve_op.__name__)))
+
+        a = tensor(shape=(None, None))
+        b = tensor(shape=(None, None, None))
+
+        if solve_op is cho_solve:
+            # cho_solve expects a tuple (a, lower) as the first input
+            out = solve_op((a, True), b, b_ndim=1)
+        else:
+            out = solve_op(a, b, b_ndim=1)
+
+        mode = get_default_mode().excluding(self.rewrite_name)
+        ref_fn = pytensor.function([a, b], out, mode=mode)
+        assert self.any_vector_b_solve(ref_fn)
+
+        mode = get_default_mode().including(self.rewrite_name)
+        opt_fn = pytensor.function([a, b], out, mode=mode)
+        assert not self.any_vector_b_solve(opt_fn)
+
+        test_a = rng.normal(size=(3, 3)).astype(config.floatX)
+        test_b = rng.normal(size=(7, 5, 3)).astype(config.floatX)
+        np.testing.assert_allclose(
+            opt_fn(test_a, test_b),
+            ref_fn(test_a, test_b),
+            rtol=1e-7 if config.floatX == "float64" else 1e-5,
+        )
+
+    def test_invalid_batched_a(self):
+        rng = np.random.default_rng(sum(map(ord, self.rewrite_name)))
+
+        # Rewrite is not applicable if a has batched dims
+        a = tensor(shape=(None, None, None))
+        b = tensor(shape=(None, None, None))
+
+        out = solve(a, b, b_ndim=1)
+
+        mode = get_default_mode().including(self.rewrite_name)
+        opt_fn = pytensor.function([a, b], out, mode=mode)
+        assert self.any_vector_b_solve(opt_fn)
+
+        ref_fn = np.vectorize(np.linalg.solve, signature="(m,m),(m)->(m)")
+
+        test_a = rng.normal(size=(5, 3, 3)).astype(config.floatX)
+        test_b = rng.normal(size=(7, 5, 3)).astype(config.floatX)
+        np.testing.assert_allclose(
+            opt_fn(test_a, test_b),
+            ref_fn(test_a, test_b),
+            rtol=1e-7 if config.floatX == "float64" else 1e-5,
+        )
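The equivalence that `test_valid_cases` checks against an un-rewritten reference function is a plain linear-algebra identity that holds independently of PyTensor: solving `a x = b` for a stack of vectors `b` is the same as one solve with those vectors laid out as columns, followed by swapping the last two axes back. A NumPy-only sketch of the identity, with made-up shapes:

    import numpy as np

    rng = np.random.default_rng(0)
    a = rng.normal(size=(3, 3))
    b = rng.normal(size=(7, 5, 3))  # two batch dimensions of length-3 vectors

    # Before the rewrite: one small solve per batch element
    looped = np.vectorize(np.linalg.solve, signature="(m,m),(m)->(m)")(a, b)

    # After the rewrite: a single solve with matrix right-hand sides
    stacked = np.swapaxes(np.linalg.solve(a, np.swapaxes(b, -1, -2)), -1, -2)

    np.testing.assert_allclose(looped, stacked)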
diff --git a/tests/tensor/test_blockwise.py b/tests/tensor/test_blockwise.py
index 1a36c57f45..b2d2dfda1f 100644
--- a/tests/tensor/test_blockwise.py
+++ b/tests/tensor/test_blockwise.py
@@ -9,10 +9,10 @@
 from pytensor.gradient import grad
 from pytensor.graph import Apply, Op
 from pytensor.graph.replace import vectorize_node
-from pytensor.tensor import tensor
+from pytensor.tensor import diagonal, log, tensor
 from pytensor.tensor.blockwise import Blockwise, _parse_gufunc_signature
 from pytensor.tensor.nlinalg import MatrixInverse
-from pytensor.tensor.slinalg import Cholesky, Solve
+from pytensor.tensor.slinalg import Cholesky, Solve, cholesky, solve_triangular
 
 
 def test_vectorize_blockwise():
@@ -88,7 +88,7 @@ def test_runtime_broadcast(mode):
     check_blockwise_runtime_broadcasting(mode)
 
 
-class TestOp(Op):
+class MyTestOp(Op):
     def make_node(self, *inputs):
         return Apply(self, inputs, [i.type() for i in inputs])
 
@@ -96,7 +96,7 @@ def perform(self, *args, **kwargs):
         raise NotImplementedError("Test Op should not be present in final graph")
 
 
-test_op = TestOp()
+test_op = MyTestOp()
 
 
 def test_vectorize_node_default_signature():
@@ -106,12 +106,12 @@ def test_vectorize_node_default_signature():
 
     vect_node = vectorize_node(node, mat, mat)
     assert isinstance(vect_node.op, Blockwise) and isinstance(
-        vect_node.op.core_op, TestOp
+        vect_node.op.core_op, MyTestOp
    )
     assert vect_node.op.signature == ("(i00),(i10,i11)->(o00),(o10,o11)")
 
     with pytest.raises(
-        ValueError, match="Signature not provided nor found in core_op TestOp"
+        ValueError, match="Signature not provided nor found in core_op MyTestOp"
     ):
         Blockwise(test_op)
 
@@ -138,7 +138,7 @@ def test_blockwise_shape():
 
     shape_fn = pytensor.function([inp], out.shape)
     assert not any(
-        isinstance(getattr(n.op, "core_op", n.op), TestOp)
+        isinstance(getattr(n.op, "core_op", n.op), MyTestOp)
         for n in shape_fn.maker.fgraph.apply_nodes
     )
     assert tuple(shape_fn(inp_test)) == (5, 3, 4)
@@ -150,13 +150,13 @@ def test_blockwise_shape():
 
     shape_fn = pytensor.function([inp], out.shape)
     assert any(
-        isinstance(getattr(n.op, "core_op", n.op), TestOp)
+        isinstance(getattr(n.op, "core_op", n.op), MyTestOp)
         for n in shape_fn.maker.fgraph.apply_nodes
     )
 
     shape_fn = pytensor.function([inp], out.shape[:-1])
     assert not any(
-        isinstance(getattr(n.op, "core_op", n.op), TestOp)
+        isinstance(getattr(n.op, "core_op", n.op), MyTestOp)
         for n in shape_fn.maker.fgraph.apply_nodes
     )
     assert tuple(shape_fn(inp_test)) == (5, 4)
@@ -174,20 +174,20 @@ def test_blockwise_shape():
 
     shape_fn = pytensor.function([inp1, inp2], [out.shape for out in outs])
     assert any(
-        isinstance(getattr(n.op, "core_op", n.op), TestOp)
+        isinstance(getattr(n.op, "core_op", n.op), MyTestOp)
         for n in shape_fn.maker.fgraph.apply_nodes
     )
 
     shape_fn = pytensor.function([inp1, inp2], outs[0].shape)
     assert not any(
-        isinstance(getattr(n.op, "core_op", n.op), TestOp)
+        isinstance(getattr(n.op, "core_op", n.op), MyTestOp)
         for n in shape_fn.maker.fgraph.apply_nodes
     )
     assert tuple(shape_fn(inp1_test, inp2_test)) == (7, 5, 3, 4)
 
     shape_fn = pytensor.function([inp1, inp2], [outs[0].shape, outs[1].shape[:-1]])
     assert not any(
-        isinstance(getattr(n.op, "core_op", n.op), TestOp)
+        isinstance(getattr(n.op, "core_op", n.op), MyTestOp)
         for n in shape_fn.maker.fgraph.apply_nodes
     )
     assert tuple(shape_fn(inp1_test, inp2_test)[0]) == (7, 5, 3, 4)
@@ -257,6 +257,8 @@ def test_perform(self):
         np.testing.assert_allclose(
             pt_func(*vec_inputs_testvals),
             np_func(*vec_inputs_testvals),
+            rtol=1e-7 if config.floatX == "float64" else 1e-5,
+            atol=1e-7 if config.floatX == "float64" else 1e-5,
         )
 
     def test_grad(self):
@@ -288,6 +290,7 @@ def test_grad(self):
         np.testing.assert_allclose(
             pt_out,
             np_out,
+            rtol=1e-7 if config.floatX == "float64" else 1e-5,
             atol=1e-6 if config.floatX == "float64" else 1e-5,
         )
 
@@ -317,3 +320,41 @@ class TestSolveVector(BlockwiseOpTester):
 class TestSolveMatrix(BlockwiseOpTester):
     core_op = Solve(lower=True, b_ndim=2)
     signature = "(m, m),(m, n) -> (m, n)"
+
+
+@pytest.mark.parametrize(
+    "mu_batch_shape", [(), (1000,), (4, 1000)], ids=lambda arg: f"mu:{arg}"
+)
+@pytest.mark.parametrize(
+    "cov_batch_shape", [(), (1000,), (4, 1000)], ids=lambda arg: f"cov:{arg}"
+)
+def test_batched_mvnormal_logp_and_dlogp(mu_batch_shape, cov_batch_shape, benchmark):
+    rng = np.random.default_rng(sum(map(ord, "batched_mvnormal")))
+
+    value_batch_shape = mu_batch_shape
+    if len(cov_batch_shape) > len(mu_batch_shape):
+        value_batch_shape = cov_batch_shape
+
+    value = tensor("value", shape=(*value_batch_shape, 10))
+    mu = tensor("mu", shape=(*mu_batch_shape, 10))
+    cov = tensor("cov", shape=(*cov_batch_shape, 10, 10))
+
+    test_values = [
+        rng.normal(size=value.type.shape),
+        rng.normal(size=mu.type.shape),
+        np.eye(cov.type.shape[-1]) * np.abs(rng.normal(size=cov.type.shape)),
+    ]
+
+    chol_cov = cholesky(cov, lower=True, on_error="raise")
+    delta_trans = solve_triangular(chol_cov, value - mu, b_ndim=1)
+    quaddist = (delta_trans**2).sum(axis=-1)
+    diag = diagonal(chol_cov, axis1=-2, axis2=-1)
+    logdet = log(diag).sum(axis=-1)
+    k = value.shape[-1]
+    norm = -0.5 * k * (np.log(2 * np.pi))
+
+    logp = norm - 0.5 * quaddist - logdet
+    dlogp = grad(logp.sum(), wrt=[value, mu, cov])
+
+    fn = pytensor.function([value, mu, cov], [logp, *dlogp])
+    benchmark(fn, *test_values)
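The `logp` assembled in this benchmark is the usual Cholesky-based evaluation of the multivariate normal density: with `cov = L @ L.T`, the term `delta_trans = solve_triangular(L, value - mu)` yields `quaddist = (value - mu)^T cov^{-1} (value - mu)`, and `log(diag(L)).sum()` equals half of `log|cov|`, which is why it appears without a 0.5 factor. A small NumPy/SciPy check of the same formula (illustrative, not part of the patch):

    import numpy as np
    from scipy import stats
    from scipy.linalg import cholesky, solve_triangular

    rng = np.random.default_rng(0)
    k = 10
    value = rng.normal(size=k)
    mu = rng.normal(size=k)
    A = rng.normal(size=(k, k))
    cov = A @ A.T + k * np.eye(k)  # positive definite by construction

    L = cholesky(cov, lower=True)
    delta_trans = solve_triangular(L, value - mu, lower=True)
    quaddist = (delta_trans**2).sum()   # (value-mu)^T cov^{-1} (value-mu)
    logdet = np.log(np.diag(L)).sum()   # 0.5 * log|cov|
    logp = -0.5 * k * np.log(2 * np.pi) - 0.5 * quaddist - logdet

    np.testing.assert_allclose(
        logp, stats.multivariate_normal(mu, cov).logpdf(value)
    )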