Skip to content

Commit 223ad7a

Browse files
committed
Add specialization rewrite for solve with batched b
1 parent 7bb18f3 commit 223ad7a

File tree

2 files changed

+129
-3
lines changed

2 files changed

+129
-3
lines changed

pytensor/tensor/rewriting/linalg.py

Lines changed: 54 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import logging
22
from typing import cast
33

4-
from pytensor.graph.rewriting.basic import node_rewriter
4+
from pytensor.graph.rewriting.basic import copy_stack_trace, node_rewriter
55
from pytensor.tensor.basic import TensorVariable, diagonal, swapaxes
66
from pytensor.tensor.blas import Dot22
77
from pytensor.tensor.blockwise import Blockwise
@@ -13,7 +13,14 @@
1313
register_specialize,
1414
register_stabilize,
1515
)
16-
from pytensor.tensor.slinalg import Cholesky, Solve, cholesky, solve, solve_triangular
16+
from pytensor.tensor.slinalg import (
17+
Cholesky,
18+
Solve,
19+
SolveBase,
20+
cholesky,
21+
solve,
22+
solve_triangular,
23+
)
1724

1825

1926
logger = logging.getLogger(__name__)
@@ -131,6 +138,51 @@ def generic_solve_to_solve_triangular(fgraph, node):
131138
]
132139

133140

141+
@register_specialize
@node_rewriter([Blockwise])
def batched_vector_b_solve_to_matrix_b_solve(fgraph, node):
    """Replace a batched Solve(a, b, b_ndim=1) by Solve(a, b.T, b_ndim=2).T

    `a` must have no batched dimensions, while `b` can have arbitrary batched dimensions.
    Only the last two dimensions of `b` and the output are swapped.
    """
    core_op = node.op.core_op

    # Only Solve-family core ops that treat `b` as a vector are candidates.
    if not isinstance(core_op, SolveBase):
        return None
    if core_op.b_ndim != 1:
        return None

    a, b = node.inputs

    # `b` must actually carry batch dimensions, otherwise there is nothing to gain.
    if b.type.ndim == 1:
        return None

    # `a` may only have degenerate (broadcastable) batch dims on the left.
    degenerate_batch_dims = a.type.broadcastable[:-2]
    if not all(degenerate_batch_dims):
        return None
    # Squeeze the degenerate dims; any that are still needed will be
    # reintroduced by the new solve.
    if degenerate_batch_dims:
        a = a.squeeze(axis=tuple(range(len(degenerate_batch_dims))))

    # Rebuild the same core op, but expecting a matrix `b` (b_ndim=2).
    props = core_op._props_dict()
    props["b_ndim"] = 2
    matrix_b_solve = Blockwise(type(core_op)(**props))

    # Solve against the transposed `b`, then transpose the result back.
    new_solve = _T(matrix_b_solve(a, _T(b)))

    copy_stack_trace(node.outputs[0], new_solve)

    return [new_solve]
184+
185+
134186
@register_canonicalize
135187
@register_stabilize
136188
@register_specialize

tests/tensor/rewriting/test_linalg.py

Lines changed: 75 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,16 @@
1717
from pytensor.tensor.math import _allclose, dot, matmul
1818
from pytensor.tensor.nlinalg import Det, MatrixInverse, matrix_inverse
1919
from pytensor.tensor.rewriting.linalg import inv_as_solve
20-
from pytensor.tensor.slinalg import Cholesky, Solve, SolveTriangular, cholesky, solve
20+
from pytensor.tensor.slinalg import (
21+
Cholesky,
22+
Solve,
23+
SolveBase,
24+
SolveTriangular,
25+
cho_solve,
26+
cholesky,
27+
solve,
28+
solve_triangular,
29+
)
2130
from pytensor.tensor.type import dmatrix, matrix, tensor, vector
2231
from tests import unittest_tools as utt
2332
from tests.test_rop import break_op
@@ -231,3 +240,68 @@ def test_local_det_chol():
231240
f = function([X], [L, det_X, X])
232241
nodes = f.maker.fgraph.toposort()
233242
assert not any(isinstance(node, Det) for node in nodes)
243+
244+
245+
class TestBatchedVectorBSolveToMatrixBSolve:
    rewrite_name = "batched_vector_b_solve_to_matrix_b_solve"

    @staticmethod
    def any_vector_b_solve(fn):
        """True if the compiled function still contains a Blockwise solve
        with b_ndim=1 (i.e. the rewrite did not fire)."""
        for apply_node in fn.maker.fgraph.apply_nodes:
            op = apply_node.op
            if (
                isinstance(op, Blockwise)
                and isinstance(op.core_op, SolveBase)
                and op.core_op.b_ndim == 1
            ):
                return True
        return False

    @pytest.mark.parametrize("solve_op", (solve, solve_triangular, cho_solve))
    def test_valid_cases(self, solve_op):
        rng = np.random.default_rng(sum(map(ord, solve_op.__name__)))

        a = tensor(shape=(None, None))
        b = tensor(shape=(None, None, None))

        if solve_op is cho_solve:
            # cho_solve expects a tuple (a, lower) as the first input
            out = solve_op((a, True), b, b_ndim=1)
        else:
            out = solve_op(a, b, b_ndim=1)

        # Reference: rewrite excluded, the vector-b solve must remain.
        ref_fn = pytensor.function(
            [a, b], out, mode=get_default_mode().excluding(self.rewrite_name)
        )
        assert self.any_vector_b_solve(ref_fn)

        # Optimized: rewrite included, the vector-b solve must be gone.
        opt_fn = pytensor.function(
            [a, b], out, mode=get_default_mode().including(self.rewrite_name)
        )
        assert not self.any_vector_b_solve(opt_fn)

        test_a = rng.normal(size=(3, 3))
        test_b = rng.normal(size=(7, 5, 3))
        np.testing.assert_allclose(
            opt_fn(test_a, test_b),
            ref_fn(test_a, test_b),
        )

    def test_invalid_batched_a(self):
        rng = np.random.default_rng(sum(map(ord, self.rewrite_name)))

        # Rewrite is not applicable if a has batched dims
        a = tensor(shape=(None, None, None))
        b = tensor(shape=(None, None, None))

        out = solve(a, b, b_ndim=1)

        # Even with the rewrite enabled, the vector-b solve must survive.
        opt_fn = pytensor.function(
            [a, b], out, mode=get_default_mode().including(self.rewrite_name)
        )
        assert self.any_vector_b_solve(opt_fn)

        # Compare against a NumPy reference with the same batched-vector signature.
        ref_fn = np.vectorize(np.linalg.solve, signature="(m,m),(m)->(m)")

        test_a = rng.normal(size=(5, 3, 3))
        test_b = rng.normal(size=(7, 5, 3))
        np.testing.assert_allclose(
            opt_fn(test_a, test_b),
            ref_fn(test_a, test_b),
        )

0 commit comments

Comments
 (0)