Add specialization rewrite for solve with batched b #482

Merged (3 commits) on Nov 11, 2023
57 changes: 55 additions & 2 deletions pytensor/tensor/rewriting/linalg.py
@@ -1,7 +1,7 @@
import logging
from typing import cast

from pytensor.graph.rewriting.basic import node_rewriter
from pytensor.graph.rewriting.basic import copy_stack_trace, node_rewriter
from pytensor.tensor.basic import TensorVariable, diagonal, swapaxes
from pytensor.tensor.blas import Dot22
from pytensor.tensor.blockwise import Blockwise
@@ -13,7 +13,14 @@
register_specialize,
register_stabilize,
)
from pytensor.tensor.slinalg import Cholesky, Solve, cholesky, solve, solve_triangular
from pytensor.tensor.slinalg import (
Cholesky,
Solve,
SolveBase,
cholesky,
solve,
solve_triangular,
)


logger = logging.getLogger(__name__)
@@ -131,6 +138,52 @@ def generic_solve_to_solve_triangular(fgraph, node):
]


@register_stabilize
Review comment (Member Author): I want to make this a specialize-only rewrite, but for now it is also registered in stabilize because PyMC includes the stabilize rewrites before calling grad; otherwise we would still end up with messy, unoptimized blockwise graphs.

@register_specialize
@node_rewriter([Blockwise])
def batched_vector_b_solve_to_matrix_b_solve(fgraph, node):
"""Replace a batched Solve(a, b, b_ndim=1) by Solve(a, b.T, b_ndim=2).T

`a` must have no batched dimensions, while `b` can have arbitrary batched dimensions.
Only the last two dimensions of `b` and the output are swapped.
"""
core_op = node.op.core_op

if not isinstance(core_op, SolveBase):
return None

if node.op.core_op.b_ndim != 1:
return None

[a, b] = node.inputs

# Check `b` is actually batched
if b.type.ndim == 1:
return None

# Check `a` is a matrix (possibly with degenerate dims on the left)
a_bcast_batch_dims = a.type.broadcastable[:-2]
if not all(a_bcast_batch_dims):
return None
# We squeeze degenerate dims; any that are still needed will be reintroduced by the new_solve
elif len(a_bcast_batch_dims):
a = a.squeeze(axis=tuple(range(len(a_bcast_batch_dims))))

# Recreate solve Op with b_ndim=2
props = core_op._props_dict()
props["b_ndim"] = 2
new_core_op = type(core_op)(**props)
matrix_b_solve = Blockwise(new_core_op)

# Apply the rewrite
new_solve = _T(matrix_b_solve(a, _T(b)))
Review comment (Member Author): Actually, we could move any one of the batched b dimensions here when there are several; picking the largest one would reduce the outer looping (see the NumPy sketch after this file's diff).


old_solve = node.outputs[0]
copy_stack_trace(old_solve, new_solve)

return [new_solve]


@register_canonicalize
@register_stabilize
@register_specialize
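
For illustration only (not part of this PR's diff): a minimal NumPy sketch of the equivalence the rewrite exploits, plus the alternative axis choice mentioned in the review comment above. The shapes mirror the tests below, and np.vectorize stands in for the Blockwise vector solve.

import numpy as np

rng = np.random.default_rng(0)
a = rng.normal(size=(3, 3))     # core matrix, no batch dimensions
b = rng.normal(size=(7, 5, 3))  # batched vector right-hand sides, batch shape (7, 5)

# What Blockwise(Solve(..., b_ndim=1)) computes: one vector solve per batch entry
batched_vector = np.vectorize(np.linalg.solve, signature="(m,m),(m)->(m)")(a, b)

# What the rewrite produces: swap the last two axes of b, run matrix solves
# (the outer loop is only over the leading batch dim of size 7), then swap back
matrix_b = np.swapaxes(np.linalg.solve(a, np.swapaxes(b, -1, -2)), -1, -2)
np.testing.assert_allclose(batched_vector, matrix_b)

# Alternative hinted at in the review comment: move the larger batch dim (size 7)
# into the matrix solve instead, leaving an outer loop of size 5
b_alt = np.moveaxis(b, 0, -1)                                 # shape (5, 3, 7)
matrix_b_alt = np.moveaxis(np.linalg.solve(a, b_alt), -1, 0)  # back to (7, 5, 3)
np.testing.assert_allclose(batched_vector, matrix_b_alt)
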
79 changes: 77 additions & 2 deletions tests/tensor/rewriting/test_linalg.py
@@ -1,7 +1,6 @@
from functools import partial

import numpy as np
import numpy.linalg
import pytest
import scipy.linalg
from numpy.testing import assert_allclose
@@ -17,7 +16,16 @@
from pytensor.tensor.math import _allclose, dot, matmul
from pytensor.tensor.nlinalg import Det, MatrixInverse, matrix_inverse
from pytensor.tensor.rewriting.linalg import inv_as_solve
from pytensor.tensor.slinalg import Cholesky, Solve, SolveTriangular, cholesky, solve
from pytensor.tensor.slinalg import (
Cholesky,
Solve,
SolveBase,
SolveTriangular,
cho_solve,
cholesky,
solve,
solve_triangular,
)
from pytensor.tensor.type import dmatrix, matrix, tensor, vector
from tests import unittest_tools as utt
from tests.test_rop import break_op
@@ -231,3 +239,70 @@ def test_local_det_chol():
f = function([X], [L, det_X, X])
nodes = f.maker.fgraph.toposort()
assert not any(isinstance(node, Det) for node in nodes)


class TestBatchedVectorBSolveToMatrixBSolve:
rewrite_name = "batched_vector_b_solve_to_matrix_b_solve"

@staticmethod
def any_vector_b_solve(fn):
return any(
(
isinstance(node.op, Blockwise)
and isinstance(node.op.core_op, SolveBase)
and node.op.core_op.b_ndim == 1
)
for node in fn.maker.fgraph.apply_nodes
)

@pytest.mark.parametrize("solve_op", (solve, solve_triangular, cho_solve))
def test_valid_cases(self, solve_op):
rng = np.random.default_rng(sum(map(ord, solve_op.__name__)))

a = tensor(shape=(None, None))
b = tensor(shape=(None, None, None))

if solve_op is cho_solve:
# cho_solve expects a tuple (a, lower) as its first input
out = solve_op((a, True), b, b_ndim=1)
else:
out = solve_op(a, b, b_ndim=1)

mode = get_default_mode().excluding(self.rewrite_name)
ref_fn = pytensor.function([a, b], out, mode=mode)
assert self.any_vector_b_solve(ref_fn)

mode = get_default_mode().including(self.rewrite_name)
opt_fn = pytensor.function([a, b], out, mode=mode)
assert not self.any_vector_b_solve(opt_fn)

test_a = rng.normal(size=(3, 3)).astype(config.floatX)
test_b = rng.normal(size=(7, 5, 3)).astype(config.floatX)
np.testing.assert_allclose(
opt_fn(test_a, test_b),
ref_fn(test_a, test_b),
rtol=1e-7 if config.floatX == "float64" else 1e-5,
)

def test_invalid_batched_a(self):
rng = np.random.default_rng(sum(map(ord, self.rewrite_name)))

# Rewrite is not applicable if a has batched dims
a = tensor(shape=(None, None, None))
b = tensor(shape=(None, None, None))

out = solve(a, b, b_ndim=1)

mode = get_default_mode().including(self.rewrite_name)
opt_fn = pytensor.function([a, b], out, mode=mode)
assert self.any_vector_b_solve(opt_fn)

ref_fn = np.vectorize(np.linalg.solve, signature="(m,m),(m)->(m)")

test_a = rng.normal(size=(5, 3, 3)).astype(config.floatX)
test_b = rng.normal(size=(7, 5, 3)).astype(config.floatX)
np.testing.assert_allclose(
opt_fn(test_a, test_b),
ref_fn(test_a, test_b),
rtol=1e-7 if config.floatX == "float64" else 1e-5,
)
65 changes: 53 additions & 12 deletions tests/tensor/test_blockwise.py
@@ -9,10 +9,10 @@
from pytensor.gradient import grad
from pytensor.graph import Apply, Op
from pytensor.graph.replace import vectorize_node
from pytensor.tensor import tensor
from pytensor.tensor import diagonal, log, tensor
from pytensor.tensor.blockwise import Blockwise, _parse_gufunc_signature
from pytensor.tensor.nlinalg import MatrixInverse
from pytensor.tensor.slinalg import Cholesky, Solve
from pytensor.tensor.slinalg import Cholesky, Solve, cholesky, solve_triangular


def test_vectorize_blockwise():
@@ -88,15 +88,15 @@ def test_runtime_broadcast(mode):
check_blockwise_runtime_broadcasting(mode)


class TestOp(Op):
class MyTestOp(Op):
def make_node(self, *inputs):
return Apply(self, inputs, [i.type() for i in inputs])

def perform(self, *args, **kwargs):
raise NotImplementedError("Test Op should not be present in final graph")


test_op = TestOp()
test_op = MyTestOp()


def test_vectorize_node_default_signature():
@@ -106,12 +106,12 @@ def test_vectorize_node_default_signature():

vect_node = vectorize_node(node, mat, mat)
assert isinstance(vect_node.op, Blockwise) and isinstance(
vect_node.op.core_op, TestOp
vect_node.op.core_op, MyTestOp
)
assert vect_node.op.signature == ("(i00),(i10,i11)->(o00),(o10,o11)")

with pytest.raises(
ValueError, match="Signature not provided nor found in core_op TestOp"
ValueError, match="Signature not provided nor found in core_op MyTestOp"
):
Blockwise(test_op)

@@ -138,7 +138,7 @@ def test_blockwise_shape():

shape_fn = pytensor.function([inp], out.shape)
assert not any(
isinstance(getattr(n.op, "core_op", n.op), TestOp)
isinstance(getattr(n.op, "core_op", n.op), MyTestOp)
for n in shape_fn.maker.fgraph.apply_nodes
)
assert tuple(shape_fn(inp_test)) == (5, 3, 4)
@@ -150,13 +150,13 @@

shape_fn = pytensor.function([inp], out.shape)
assert any(
isinstance(getattr(n.op, "core_op", n.op), TestOp)
isinstance(getattr(n.op, "core_op", n.op), MyTestOp)
for n in shape_fn.maker.fgraph.apply_nodes
)

shape_fn = pytensor.function([inp], out.shape[:-1])
assert not any(
isinstance(getattr(n.op, "core_op", n.op), TestOp)
isinstance(getattr(n.op, "core_op", n.op), MyTestOp)
for n in shape_fn.maker.fgraph.apply_nodes
)
assert tuple(shape_fn(inp_test)) == (5, 4)
@@ -174,20 +174,20 @@ def test_blockwise_shape():

shape_fn = pytensor.function([inp1, inp2], [out.shape for out in outs])
assert any(
isinstance(getattr(n.op, "core_op", n.op), TestOp)
isinstance(getattr(n.op, "core_op", n.op), MyTestOp)
for n in shape_fn.maker.fgraph.apply_nodes
)

shape_fn = pytensor.function([inp1, inp2], outs[0].shape)
assert not any(
isinstance(getattr(n.op, "core_op", n.op), TestOp)
isinstance(getattr(n.op, "core_op", n.op), MyTestOp)
for n in shape_fn.maker.fgraph.apply_nodes
)
assert tuple(shape_fn(inp1_test, inp2_test)) == (7, 5, 3, 4)

shape_fn = pytensor.function([inp1, inp2], [outs[0].shape, outs[1].shape[:-1]])
assert not any(
isinstance(getattr(n.op, "core_op", n.op), TestOp)
isinstance(getattr(n.op, "core_op", n.op), MyTestOp)
for n in shape_fn.maker.fgraph.apply_nodes
)
assert tuple(shape_fn(inp1_test, inp2_test)[0]) == (7, 5, 3, 4)
@@ -257,6 +257,8 @@ def test_perform(self):
np.testing.assert_allclose(
pt_func(*vec_inputs_testvals),
np_func(*vec_inputs_testvals),
rtol=1e-7 if config.floatX == "float64" else 1e-5,
atol=1e-7 if config.floatX == "float64" else 1e-5,
)

def test_grad(self):
@@ -288,6 +290,7 @@ def test_grad(self):
np.testing.assert_allclose(
pt_out,
np_out,
rtol=1e-7 if config.floatX == "float64" else 1e-5,
atol=1e-6 if config.floatX == "float64" else 1e-5,
)

@@ -317,3 +320,41 @@ class TestSolveVector(BlockwiseOpTester):
class TestSolveMatrix(BlockwiseOpTester):
core_op = Solve(lower=True, b_ndim=2)
signature = "(m, m),(m, n) -> (m, n)"


@pytest.mark.parametrize(
"mu_batch_shape", [(), (1000,), (4, 1000)], ids=lambda arg: f"mu:{arg}"
)
@pytest.mark.parametrize(
"cov_batch_shape", [(), (1000,), (4, 1000)], ids=lambda arg: f"cov:{arg}"
)
def test_batched_mvnormal_logp_and_dlogp(mu_batch_shape, cov_batch_shape, benchmark):
rng = np.random.default_rng(sum(map(ord, "batched_mvnormal")))

value_batch_shape = mu_batch_shape
if len(cov_batch_shape) > len(mu_batch_shape):
value_batch_shape = cov_batch_shape

value = tensor("value", shape=(*value_batch_shape, 10))
mu = tensor("mu", shape=(*mu_batch_shape, 10))
cov = tensor("cov", shape=(*cov_batch_shape, 10, 10))

test_values = [
rng.normal(size=value.type.shape),
rng.normal(size=mu.type.shape),
np.eye(cov.type.shape[-1]) * np.abs(rng.normal(size=cov.type.shape)),
]

chol_cov = cholesky(cov, lower=True, on_error="raise")
delta_trans = solve_triangular(chol_cov, value - mu, b_ndim=1)
quaddist = (delta_trans**2).sum(axis=-1)
diag = diagonal(chol_cov, axis1=-2, axis2=-1)
logdet = log(diag).sum(axis=-1)
k = value.shape[-1]
norm = -0.5 * k * (np.log(2 * np.pi))

logp = norm - 0.5 * quaddist - logdet
dlogp = grad(logp.sum(), wrt=[value, mu, cov])

fn = pytensor.function([value, mu, cov], [logp, *dlogp])
benchmark(fn, *test_values)