Skip to content

Commit 29cbdb0

Browse files
Eagerly noop single block_diag with single input
1 parent a17ce6c commit 29cbdb0

File tree

1 file changed

+3
-0
lines changed

1 file changed

+3
-0
lines changed

pytensor/tensor/slinalg.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1402,6 +1402,9 @@ def block_diag(*matrices: TensorVariable):
14021402
[0, 0, 5, 6],
14031403
[0, 0, 7, 8]])
14041404
"""
1405+
if len(matrices) == 1:
1406+
return matrices[0]
1407+
14051408
_block_diagonal_matrix = Blockwise(BlockDiagonal(n_inputs=len(matrices)))
14061409
return _block_diagonal_matrix(*matrices)
14071410

0 commit comments

Comments
 (0)