File tree 2 files changed +51
-0
lines changed
pytensor/tensor/rewriting
2 files changed +51
-0
lines changed Original file line number Diff line number Diff line change 11
11
from pytensor .scalar .basic import Mul
12
12
from pytensor .tensor .basic import (
13
13
AllocDiag ,
14
+ ExtractDiag ,
14
15
Eye ,
15
16
TensorVariable ,
17
+ concatenate ,
18
+ diag ,
16
19
diagonal ,
17
20
)
18
21
from pytensor .tensor .blas import Dot22
@@ -611,3 +614,24 @@ def rewrite_inv_inv(fgraph, node):
611
614
):
612
615
return None
613
616
return [potential_inner_inv .inputs [0 ]]
617
+
618
+
619
+ @register_canonicalize
620
+ @register_stabilize
621
+ @node_rewriter ([ExtractDiag ])
622
+ def rewrite_diag_blockdiag (fgraph , node ):
623
+ # Check for inner block_diag operation
624
+ potential_blockdiag = node .inputs [0 ].owner
625
+ if not (
626
+ potential_blockdiag
627
+ and isinstance (potential_blockdiag .op , Blockwise )
628
+ and isinstance (potential_blockdiag .op .core_op , BlockDiagonal )
629
+ ):
630
+ return None
631
+
632
+ # Find the composing sub_matrices
633
+ submatrices = potential_blockdiag .inputs
634
+ submatrices_diag = [diag (submatrices [i ]) for i in range (len (submatrices ))]
635
+ output = [concatenate (submatrices_diag )]
636
+
637
+ return output
Original file line number Diff line number Diff line change @@ -568,3 +568,30 @@ def get_pt_function(x, op_name):
568
568
op2 = get_pt_function (op1 , inv_op_2 )
569
569
rewritten_out = rewrite_graph (op2 )
570
570
assert rewritten_out == x
571
+
572
+
573
+ def test_diag_blockdiag_rewrite ():
574
+ n_matrices = 100
575
+ matrix_size = (5 , 5 )
576
+ sub_matrices = pt .tensor ("sub_matrices" , shape = (n_matrices , * matrix_size ))
577
+ bd_output = pt .linalg .block_diag (* [sub_matrices [i ] for i in range (n_matrices )])
578
+ diag_output = pt .diag (bd_output )
579
+ f_rewritten = function ([sub_matrices ], diag_output , mode = "FAST_RUN" )
580
+
581
+ # Rewrite Test
582
+ nodes = f_rewritten .maker .fgraph .apply_nodes
583
+ assert not any (isinstance (node .op , BlockDiagonal ) for node in nodes )
584
+
585
+ # Value Test
586
+ sub_matrices_test = np .random .rand (n_matrices , * matrix_size )
587
+ bd_output_test = scipy .linalg .block_diag (
588
+ * [sub_matrices_test [i ] for i in range (n_matrices )]
589
+ )
590
+ diag_output_test = np .diag (bd_output_test )
591
+ rewritten_val = f_rewritten (sub_matrices_test )
592
+ assert_allclose (
593
+ diag_output_test ,
594
+ rewritten_val ,
595
+ atol = 1e-3 if config .floatX == "float32" else 1e-8 ,
596
+ rtol = 1e-3 if config .floatX == "float32" else 1e-8 ,
597
+ )
You can’t perform that action at this time.
0 commit comments