Skip to content

Commit 3a92af3

Browse files
committed
Add rewrite for solve with batched b
1 parent 1106f7c commit 3a92af3

File tree

3 files changed

+135
-4
lines changed

3 files changed

+135
-4
lines changed

pytensor/tensor/rewriting/linalg.py

Lines changed: 55 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import logging
22
from typing import cast
33

4-
from pytensor.graph.rewriting.basic import node_rewriter
4+
from pytensor.graph.rewriting.basic import copy_stack_trace, node_rewriter
55
from pytensor.tensor.basic import TensorVariable, diagonal, swapaxes
66
from pytensor.tensor.blas import Dot22
77
from pytensor.tensor.blockwise import Blockwise
@@ -13,7 +13,14 @@
1313
register_specialize,
1414
register_stabilize,
1515
)
16-
from pytensor.tensor.slinalg import Cholesky, Solve, cholesky, solve, solve_triangular
16+
from pytensor.tensor.slinalg import (
17+
Cholesky,
18+
Solve,
19+
SolveBase,
20+
cholesky,
21+
solve,
22+
solve_triangular,
23+
)
1724

1825

1926
logger = logging.getLogger(__name__)
@@ -131,6 +138,52 @@ def generic_solve_to_solve_triangular(fgraph, node):
131138
]
132139

133140

141+
@register_stabilize
142+
@register_specialize
143+
@node_rewriter([Blockwise])
144+
def batched_vector_b_solve_to_matrix_b_solve(fgraph, node):
145+
"""Replace a batched Solve(a, b, b_ndim=1) by Solve(a, b.T, b_ndim=2).T
146+
147+
`a` must have no batched dimensions, while `b` can have arbitrary batched dimensions.
148+
Only the last two dimensions of `b` and the output are swapped.
149+
"""
150+
core_op = node.op.core_op
151+
152+
if not isinstance(core_op, SolveBase):
153+
return None
154+
155+
if node.op.core_op.b_ndim != 1:
156+
return None
157+
158+
[a, b] = node.inputs
159+
160+
# Check `b` is actually batched
161+
if b.type.ndim == 1:
162+
return None
163+
164+
# Check `a` is a matrix (possibly with degenerate dims on the left)
165+
a_bcast_batch_dims = a.type.broadcastable[:-2]
166+
if not all(a_bcast_batch_dims):
167+
return None
168+
# We squeeze degenerate dims; any that are still needed will be reintroduced by the new solve
169+
elif len(a_bcast_batch_dims):
170+
a = a.squeeze(axis=tuple(range(len(a_bcast_batch_dims))))
171+
172+
# Recreate solve Op with b_ndim=2
173+
props = core_op._props_dict()
174+
props["b_ndim"] = 2
175+
new_core_op = type(core_op)(**props)
176+
matrix_b_solve = Blockwise(new_core_op)
177+
178+
# Apply the rewrite
179+
new_solve = _T(matrix_b_solve(a, _T(b)))
180+
181+
old_solve = node.outputs[0]
182+
copy_stack_trace(old_solve, new_solve)
183+
184+
return [new_solve]
185+
186+
134187
@register_canonicalize
135188
@register_stabilize
136189
@register_specialize

tests/tensor/rewriting/test_linalg.py

Lines changed: 77 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
from functools import partial
22

33
import numpy as np
4-
import numpy.linalg
54
import pytest
65
import scipy.linalg
76
from numpy.testing import assert_allclose
@@ -17,7 +16,16 @@
1716
from pytensor.tensor.math import _allclose, dot, matmul
1817
from pytensor.tensor.nlinalg import Det, MatrixInverse, matrix_inverse
1918
from pytensor.tensor.rewriting.linalg import inv_as_solve
20-
from pytensor.tensor.slinalg import Cholesky, Solve, SolveTriangular, cholesky, solve
19+
from pytensor.tensor.slinalg import (
20+
Cholesky,
21+
Solve,
22+
SolveBase,
23+
SolveTriangular,
24+
cho_solve,
25+
cholesky,
26+
solve,
27+
solve_triangular,
28+
)
2129
from pytensor.tensor.type import dmatrix, matrix, tensor, vector
2230
from tests import unittest_tools as utt
2331
from tests.test_rop import break_op
@@ -231,3 +239,70 @@ def test_local_det_chol():
231239
f = function([X], [L, det_X, X])
232240
nodes = f.maker.fgraph.toposort()
233241
assert not any(isinstance(node, Det) for node in nodes)
242+
243+
244+
class TestBatchedVectorBSolveToMatrixBSolve:
245+
rewrite_name = "batched_vector_b_solve_to_matrix_b_solve"
246+
247+
@staticmethod
248+
def any_vector_b_solve(fn):
249+
return any(
250+
(
251+
isinstance(node.op, Blockwise)
252+
and isinstance(node.op.core_op, SolveBase)
253+
and node.op.core_op.b_ndim == 1
254+
)
255+
for node in fn.maker.fgraph.apply_nodes
256+
)
257+
258+
@pytest.mark.parametrize("solve_op", (solve, solve_triangular, cho_solve))
259+
def test_valid_cases(self, solve_op):
260+
rng = np.random.default_rng(sum(map(ord, solve_op.__name__)))
261+
262+
a = tensor(shape=(None, None))
263+
b = tensor(shape=(None, None, None))
264+
265+
if solve_op is cho_solve:
266+
# cho_solve expects a tuple (a, lower) as the first input
267+
out = solve_op((a, True), b, b_ndim=1)
268+
else:
269+
out = solve_op(a, b, b_ndim=1)
270+
271+
mode = get_default_mode().excluding(self.rewrite_name)
272+
ref_fn = pytensor.function([a, b], out, mode=mode)
273+
assert self.any_vector_b_solve(ref_fn)
274+
275+
mode = get_default_mode().including(self.rewrite_name)
276+
opt_fn = pytensor.function([a, b], out, mode=mode)
277+
assert not self.any_vector_b_solve(opt_fn)
278+
279+
test_a = rng.normal(size=(3, 3)).astype(config.floatX)
280+
test_b = rng.normal(size=(7, 5, 3)).astype(config.floatX)
281+
np.testing.assert_allclose(
282+
opt_fn(test_a, test_b),
283+
ref_fn(test_a, test_b),
284+
rtol=1e-7 if config.floatX == "float64" else 1e-5,
285+
)
286+
287+
def test_invalid_batched_a(self):
288+
rng = np.random.default_rng(sum(map(ord, self.rewrite_name)))
289+
290+
# Rewrite is not applicable if a has batched dims
291+
a = tensor(shape=(None, None, None))
292+
b = tensor(shape=(None, None, None))
293+
294+
out = solve(a, b, b_ndim=1)
295+
296+
mode = get_default_mode().including(self.rewrite_name)
297+
opt_fn = pytensor.function([a, b], out, mode=mode)
298+
assert self.any_vector_b_solve(opt_fn)
299+
300+
ref_fn = np.vectorize(np.linalg.solve, signature="(m,m),(m)->(m)")
301+
302+
test_a = rng.normal(size=(5, 3, 3)).astype(config.floatX)
303+
test_b = rng.normal(size=(7, 5, 3)).astype(config.floatX)
304+
np.testing.assert_allclose(
305+
opt_fn(test_a, test_b),
306+
ref_fn(test_a, test_b),
307+
rtol=1e-7 if config.floatX == "float64" else 1e-5,
308+
)

tests/tensor/test_blockwise.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -257,6 +257,8 @@ def test_perform(self):
257257
np.testing.assert_allclose(
258258
pt_func(*vec_inputs_testvals),
259259
np_func(*vec_inputs_testvals),
260+
rtol=1e-7 if config.floatX == "float64" else 1e-5,
261+
atol=1e-7 if config.floatX == "float64" else 1e-5,
260262
)
261263

262264
def test_grad(self):
@@ -288,6 +290,7 @@ def test_grad(self):
288290
np.testing.assert_allclose(
289291
pt_out,
290292
np_out,
293+
rtol=1e-7 if config.floatX == "float64" else 1e-5,
291294
atol=1e-6 if config.floatX == "float64" else 1e-5,
292295
)
293296

0 commit comments

Comments
 (0)