File tree Expand file tree Collapse file tree 2 files changed +46
-0
lines changed
pytensor/tensor/rewriting Expand file tree Collapse file tree 2 files changed +46
-0
lines changed Original file line number Diff line number Diff line change @@ -821,6 +821,30 @@ def rewrite_inv_diag_to_diag_reciprocal(fgraph, node):
821821 return [eye_input / non_eye_input ]
822822
823823
824+ @register_canonicalize
825+ @register_stabilize
826+ @node_rewriter ([blockwise_of (BlockDiagonal )])
827+ def remove_useless_blockdiag (fgraph , node ):
828+ """
829+ This rewrite takes advantage of the fact that block_diag with only one input is the input itself.
830+
831+ Parameters
832+ ----------
833+ fgraph: FunctionGraph
834+ Function graph being optimized
835+ node: Apply
836+ Node of the function graph to be optimized
837+
838+ Returns
839+ -------
840+ list of Variable, optional
841+ List of optimized variables, or None if no optimization was performed
842+ """
843+ if len (node .inputs ) != 1 :
844+ return None
845+ return [node .inputs [0 ]]
846+
847+
824848@register_canonicalize
825849@register_stabilize
826850@node_rewriter ([ExtractDiag ])
Original file line number Diff line number Diff line change @@ -786,6 +786,28 @@ def test_det_blockdiag_rewrite():
786786 )
787787
788788
789+ def test_remove_useless_blockdiag_rewrite ():
790+ A = pt .matrix ("A" , shape = (2 , 2 ))
791+
792+ # Use the raw Op to dodge the eager return in pt.linalg.block_diag when n_inputs=1
793+ bd = Blockwise (BlockDiagonal (n_inputs = 1 ))(A )
794+ bd .dprint ()
795+ fg = FunctionGraph (inputs = [A ], outputs = [bd ])
796+ assert (
797+ any (
798+ isinstance (node .op , BlockDiagonal )
799+ or (
800+ isinstance (node .op , Blockwise )
801+ and isinstance (node .op .core_op , BlockDiagonal )
802+ )
803+ )
804+ for node in fg .toposort ()
805+ )
806+
807+ rewritten = rewrite_graph (bd , include = ("canonicalize" ,))
808+ utt .assert_equal_computations ([rewritten ], [A ])
809+
810+
789811def test_slogdet_blockdiag_rewrite ():
790812 n_matrices = 10
791813 matrix_size = (5 , 5 )
You can’t perform that action at this time.
0 commit comments