Skip to content

Commit a30fcc1

Browse files
committed
Add specialization rewrite for solve with batched b
1 parent 7bb18f3 commit a30fcc1

File tree

2 files changed

+118
-4
lines changed

2 files changed

+118
-4
lines changed

pytensor/tensor/rewriting/linalg.py

Lines changed: 48 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import logging
22
from typing import cast
33

4-
from pytensor.graph.rewriting.basic import node_rewriter
4+
from pytensor.graph.rewriting.basic import node_rewriter, copy_stack_trace
55
from pytensor.tensor.basic import TensorVariable, diagonal, swapaxes
66
from pytensor.tensor.blas import Dot22
77
from pytensor.tensor.blockwise import Blockwise
@@ -13,8 +13,7 @@
1313
register_specialize,
1414
register_stabilize,
1515
)
16-
from pytensor.tensor.slinalg import Cholesky, Solve, cholesky, solve, solve_triangular
17-
16+
from pytensor.tensor.slinalg import Cholesky, Solve, cholesky, solve, solve_triangular, SolveBase
1817

1918
logger = logging.getLogger(__name__)
2019

@@ -131,6 +130,52 @@ def generic_solve_to_solve_triangular(fgraph, node):
131130
]
132131

133132

133+
@register_specialize
@node_rewriter([Blockwise])
def batched_vector_b_solve_to_matrix_b_solve(fgraph, node):
    """Replace a batched Solve(a, b, b_ndim=1) by Solve(a, b.T, b_ndim=2).T

    This works when `a` is a matrix, and `b` has arbitrary number of dimensions.
    Only the last two dimensions are swapped.

    Returns
    -------
    list[TensorVariable] | None
        The single replacement output, or None when the rewrite does not apply.
    """
    core_op = node.op.core_op

    if not isinstance(core_op, SolveBase):
        return None

    # Only solves that treat `b` as a (batched) vector are eligible.
    # NOTE: use the already-bound `core_op` instead of re-traversing
    # `node.op.core_op` a second time.
    if core_op.b_ndim != 1:
        return None

    [a, b] = node.inputs

    # Check `b` is actually batched
    if b.type.ndim == 1:
        return None

    # Check `a` is a matrix (possibly with degenerate dims on the left)
    a_bcast_batch_dims = a.type.broadcastable[:-2]
    if not all(a_bcast_batch_dims):
        return None
    # We squeeze degenerate dims, any that are still needed will be introduced by the new_solve
    elif len(a_bcast_batch_dims):
        a = a.squeeze(axis=tuple(range(len(a_bcast_batch_dims))))

    # Recreate solve Op with b_ndim=2
    props = core_op._props_dict()
    props["b_ndim"] = 2
    new_core_op = type(core_op)(**props)
    matrix_b_solve = Blockwise(new_core_op)

    # Apply the rewrite: solve against the transposed batch of vectors,
    # then transpose the result back.
    new_solve = _T(matrix_b_solve(a, _T(b)))

    old_solve = node.outputs[0]
    copy_stack_trace(old_solve, new_solve)

    return [new_solve]
176+
177+
178+
134179
@register_canonicalize
135180
@register_stabilize
136181
@register_specialize

tests/tensor/rewriting/test_linalg.py

Lines changed: 70 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@
1717
from pytensor.tensor.math import _allclose, dot, matmul
1818
from pytensor.tensor.nlinalg import Det, MatrixInverse, matrix_inverse
1919
from pytensor.tensor.rewriting.linalg import inv_as_solve
20-
from pytensor.tensor.slinalg import Cholesky, Solve, SolveTriangular, cholesky, solve
20+
from pytensor.tensor.slinalg import Cholesky, Solve, SolveTriangular, cholesky, solve, cho_solve, solve_triangular, \
21+
SolveBase
2122
from pytensor.tensor.type import dmatrix, matrix, tensor, vector
2223
from tests import unittest_tools as utt
2324
from tests.test_rop import break_op
@@ -231,3 +232,71 @@ def test_local_det_chol():
231232
f = function([X], [L, det_X, X])
232233
nodes = f.maker.fgraph.toposort()
233234
assert not any(isinstance(node, Det) for node in nodes)
235+
236+
237+
class TestBatchedVectorBSolveToMatrixBSolve:
    """Tests for the `batched_vector_b_solve_to_matrix_b_solve` rewrite."""

    rewrite_name = "batched_vector_b_solve_to_matrix_b_solve"

    @staticmethod
    def any_vector_b_solve(fn):
        """Return True if the compiled graph still contains a Solve with b_ndim=1."""
        for apply_node in fn.maker.fgraph.apply_nodes:
            op = apply_node.op
            if (
                isinstance(op, Blockwise)
                and isinstance(op.core_op, SolveBase)
                and op.core_op.b_ndim == 1
            ):
                return True
        return False

    @pytest.mark.parametrize("solve_op", (solve, solve_triangular, cho_solve))
    def test_valid_cases(self, solve_op):
        rng = np.random.default_rng(sum(map(ord, solve_op.__name__)))

        a = tensor(shape=(None, None))
        b = tensor(shape=(None, None, None))

        # cho_solves expects a tuple (a, lower) as the first input
        out = (
            solve_op((a, True), b, b_ndim=1)
            if solve_op is cho_solve
            else solve_op(a, b, b_ndim=1)
        )

        # Without the rewrite, a vector-b solve must remain in the graph.
        ref_fn = pytensor.function(
            [a, b], out, mode=get_default_mode().excluding(self.rewrite_name)
        )
        assert self.any_vector_b_solve(ref_fn)

        # With the rewrite, all vector-b solves must have been replaced.
        opt_fn = pytensor.function(
            [a, b], out, mode=get_default_mode().including(self.rewrite_name)
        )
        assert not self.any_vector_b_solve(opt_fn)

        test_a = rng.normal(size=(3, 3))
        test_b = rng.normal(size=(7, 5, 3))
        np.testing.assert_allclose(
            opt_fn(test_a, test_b),
            ref_fn(test_a, test_b),
        )

    def test_invalid_batched_a(self):
        rng = np.random.default_rng(sum(map(ord, self.rewrite_name)))

        # Rewrite is not applicable if a has batched dims
        a = tensor(shape=(None, None, None))
        b = tensor(shape=(None, None, None))
        out = solve(a, b, b_ndim=1)

        opt_fn = pytensor.function(
            [a, b], out, mode=get_default_mode().including(self.rewrite_name)
        )
        # The vector-b solve must survive, since the rewrite cannot fire here.
        assert self.any_vector_b_solve(opt_fn)

        # Use numpy's batched solve as the numerical reference.
        ref_fn = np.vectorize(np.linalg.solve, signature="(m,m),(m)->(m)")

        test_a = rng.normal(size=(5, 3, 3))
        test_b = rng.normal(size=(7, 5, 3))
        np.testing.assert_allclose(
            opt_fn(test_a, test_b),
            ref_fn(test_a, test_b),
        )
302+

0 commit comments

Comments
 (0)