Skip to content

Commit ee050e5

Browse files
Add remove_useless_block_diag rewrite
1 parent 29cbdb0 commit ee050e5

File tree

2 files changed

+46
-0
lines changed

2 files changed

+46
-0
lines changed

pytensor/tensor/rewriting/linalg.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -821,6 +821,30 @@ def rewrite_inv_diag_to_diag_reciprocal(fgraph, node):
821821
return [eye_input / non_eye_input]
822822

823823

824+
@register_canonicalize
825+
@register_stabilize
826+
@node_rewriter([blockwise_of(BlockDiagonal)])
827+
def remove_useless_blockdiag(fgraph, node):
828+
"""
829+
This rewrite takes advantage of the fact that block_diag with only one input is the input itself.
830+
831+
Parameters
832+
----------
833+
fgraph: FunctionGraph
834+
Function graph being optimized
835+
node: Apply
836+
Node of the function graph to be optimized
837+
838+
Returns
839+
-------
840+
list of Variable, optional
841+
List of optimized variables, or None if no optimization was performed
842+
"""
843+
if len(node.inputs) != 1:
844+
return None
845+
return [node.inputs[0]]
846+
847+
824848
@register_canonicalize
825849
@register_stabilize
826850
@node_rewriter([ExtractDiag])

tests/tensor/rewriting/test_linalg.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -786,6 +786,28 @@ def test_det_blockdiag_rewrite():
786786
)
787787

788788

789+
def test_remove_useless_blockdiag_rewrite():
790+
A = pt.matrix("A", shape=(2, 2))
791+
792+
# Use the raw Op to dodge the eager return in pt.linalg.block_diag when n_inputs=1
793+
bd = Blockwise(BlockDiagonal(n_inputs=1))(A)
794+
bd.dprint()
795+
fg = FunctionGraph(inputs=[A], outputs=[bd])
796+
assert (
797+
any(
798+
isinstance(node.op, BlockDiagonal)
799+
or (
800+
isinstance(node.op, Blockwise)
801+
and isinstance(node.op.core_op, BlockDiagonal)
802+
)
803+
)
804+
for node in fg.toposort()
805+
)
806+
807+
rewritten = rewrite_graph(bd, include=("canonicalize",))
808+
utt.assert_equal_computations([rewritten], [A])
809+
810+
789811
def test_slogdet_blockdiag_rewrite():
790812
n_matrices = 10
791813
matrix_size = (5, 5)

0 commit comments

Comments
 (0)