Skip to content

Commit 8b2ef28

Browse files
committed
added rewrite for diag(block_diag)
1 parent 7fffec6 commit 8b2ef28

File tree

2 files changed

+51
-0
lines changed

2 files changed

+51
-0
lines changed

pytensor/tensor/rewriting/linalg.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,11 @@
1111
from pytensor.scalar.basic import Mul
1212
from pytensor.tensor.basic import (
1313
AllocDiag,
14+
ExtractDiag,
1415
Eye,
1516
TensorVariable,
17+
concatenate,
18+
diag,
1619
diagonal,
1720
)
1821
from pytensor.tensor.blas import Dot22
@@ -611,3 +614,24 @@ def rewrite_inv_inv(fgraph, node):
611614
):
612615
return None
613616
return [potential_inner_inv.inputs[0]]
617+
618+
619+
@register_canonicalize
620+
@register_stabilize
621+
@node_rewriter([ExtractDiag])
622+
def rewrite_diag_blockdiag(fgraph, node):
623+
# Check for inner block_diag operation
624+
potential_blockdiag = node.inputs[0].owner
625+
if not (
626+
potential_blockdiag
627+
and isinstance(potential_blockdiag.op, Blockwise)
628+
and isinstance(potential_blockdiag.op.core_op, BlockDiagonal)
629+
):
630+
return None
631+
632+
# Find the composing sub_matrices
633+
submatrices = potential_blockdiag.inputs
634+
submatrices_diag = [diag(submatrices[i]) for i in range(len(submatrices))]
635+
output = [concatenate(submatrices_diag)]
636+
637+
return output

tests/tensor/rewriting/test_linalg.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -568,3 +568,30 @@ def get_pt_function(x, op_name):
568568
op2 = get_pt_function(op1, inv_op_2)
569569
rewritten_out = rewrite_graph(op2)
570570
assert rewritten_out == x
571+
572+
573+
def test_diag_blockdiag_rewrite():
574+
n_matrices = 100
575+
matrix_size = (5, 5)
576+
sub_matrices = pt.tensor("sub_matrices", shape=(n_matrices, *matrix_size))
577+
bd_output = pt.linalg.block_diag(*[sub_matrices[i] for i in range(n_matrices)])
578+
diag_output = pt.diag(bd_output)
579+
f_rewritten = function([sub_matrices], diag_output, mode="FAST_RUN")
580+
581+
# Rewrite Test
582+
nodes = f_rewritten.maker.fgraph.apply_nodes
583+
assert not any(isinstance(node.op, BlockDiagonal) for node in nodes)
584+
585+
# Value Test
586+
sub_matrices_test = np.random.rand(n_matrices, *matrix_size)
587+
bd_output_test = scipy.linalg.block_diag(
588+
*[sub_matrices_test[i] for i in range(n_matrices)]
589+
)
590+
diag_output_test = np.diag(bd_output_test)
591+
rewritten_val = f_rewritten(sub_matrices_test)
592+
assert_allclose(
593+
diag_output_test,
594+
rewritten_val,
595+
atol=1e-3 if config.floatX == "float32" else 1e-8,
596+
rtol=1e-3 if config.floatX == "float32" else 1e-8,
597+
)

0 commit comments

Comments
 (0)