Added rewrites involving block diagonal matrices #967

Merged · 4 commits · Oct 8, 2024
117 changes: 117 additions & 0 deletions pytensor/tensor/rewriting/linalg.py
@@ -12,8 +12,11 @@
from pytensor.scalar.basic import Mul
from pytensor.tensor.basic import (
AllocDiag,
ExtractDiag,
Eye,
TensorVariable,
concatenate,
diag,
diagonal,
)
from pytensor.tensor.blas import Dot22
@@ -29,6 +32,7 @@
inv,
kron,
pinv,
slogdet,
svd,
)
from pytensor.tensor.rewriting.basic import (
@@ -701,3 +705,116 @@ def rewrite_inv_diag_to_diag_reciprocal(fgraph, node):
non_eye_input = pt.shape_padaxis(non_eye_diag, -2)

return [eye_input / non_eye_input]


@register_canonicalize
@register_stabilize
@node_rewriter([ExtractDiag])
def rewrite_diag_blockdiag(fgraph, node):
"""
This rewrite simplifies extracting the diagonal of a block-diagonal matrix by concatenating the diagonals of the individual sub-matrices.

diag(block_diag(a, b, c, ...)) = concat(diag(a), diag(b), diag(c), ...)

Parameters
----------
fgraph: FunctionGraph
Function graph being optimized
node: Apply
Node of the function graph to be optimized

Returns
-------
list of Variable, optional
List of optimized variables, or None if no optimization was performed
"""
# Check for inner block_diag operation
potential_block_diag = node.inputs[0].owner
if not (
potential_block_diag
and isinstance(potential_block_diag.op, Blockwise)
and isinstance(potential_block_diag.op.core_op, BlockDiagonal)
):
return None

# Find the composing sub_matrices
submatrices = potential_block_diag.inputs
submatrices_diag = [diag(submatrices[i]) for i in range(len(submatrices))]

return [concatenate(submatrices_diag)]


@register_canonicalize
@register_stabilize
@node_rewriter([det])
def rewrite_det_blockdiag(fgraph, node):
"""
This rewrite simplifies the determinant of a block-diagonal matrix by extracting the individual sub-matrices and returning the product of their determinants.

det(block_diag(a, b, c, ...)) = prod(det(a), det(b), det(c), ...)

Parameters
----------
fgraph: FunctionGraph
Function graph being optimized
node: Apply
Node of the function graph to be optimized

Returns
-------
list of Variable, optional
List of optimized variables, or None if no optimization was performed
"""
# Check for inner block_diag operation
potential_block_diag = node.inputs[0].owner
if not (
potential_block_diag
and isinstance(potential_block_diag.op, Blockwise)
and isinstance(potential_block_diag.op.core_op, BlockDiagonal)
):
return None

# Find the composing sub_matrices
sub_matrices = potential_block_diag.inputs
det_sub_matrices = [det(sub_matrices[i]) for i in range(len(sub_matrices))]

return [prod(det_sub_matrices)]


@register_canonicalize
@register_stabilize
@node_rewriter([slogdet])
def rewrite_slogdet_blockdiag(fgraph, node):
"""
This rewrite simplifies the slogdet of a block-diagonal matrix by extracting the individual sub-matrices and computing the sign and log-determinant from those.

slogdet(block_diag(a, b, c, ...)) = prod(sign(a), sign(b), sign(c), ...), sum(logdet(a), logdet(b), logdet(c), ...)

Parameters
----------
fgraph: FunctionGraph
Function graph being optimized
node: Apply
Node of the function graph to be optimized

Returns
-------
list of Variable, optional
List of optimized variables, or None if no optimization was performed
"""
# Check for inner block_diag operation
potential_block_diag = node.inputs[0].owner
if not (
potential_block_diag
and isinstance(potential_block_diag.op, Blockwise)
and isinstance(potential_block_diag.op.core_op, BlockDiagonal)
):
return None

# Find the composing sub_matrices
sub_matrices = potential_block_diag.inputs
sign_sub_matrices, logdet_sub_matrices = zip(
*[slogdet(sub_matrices[i]) for i in range(len(sub_matrices))]
)

return [prod(sign_sub_matrices), sum(logdet_sub_matrices)]
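
The three identities used by these rewrites are easy to sanity-check numerically, independent of PyTensor. The sketch below (illustration only, not part of this PR) verifies all three with NumPy and SciPy on random sub-matrices of mixed sizes:

import numpy as np
import scipy.linalg

rng = np.random.default_rng(0)
# Three random square sub-matrices of different sizes
blocks = [rng.standard_normal((n, n)) for n in (2, 3, 4)]
bd = scipy.linalg.block_diag(*blocks)

# diag(block_diag(...)) == concatenation of the individual diagonals
assert np.allclose(np.diag(bd), np.concatenate([np.diag(b) for b in blocks]))

# det(block_diag(...)) == product of the individual determinants
assert np.allclose(np.linalg.det(bd), np.prod([np.linalg.det(b) for b in blocks]))

# slogdet: sign is the product of signs, logdet is the sum of logdets
signs, logdets = zip(*[np.linalg.slogdet(b) for b in blocks])
sign, logdet = np.linalg.slogdet(bd)
assert np.isclose(sign, np.prod(signs))
assert np.isclose(logdet, np.sum(logdets))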
89 changes: 89 additions & 0 deletions tests/tensor/rewriting/test_linalg.py
@@ -662,3 +662,92 @@ def test_inv_diag_from_diag(inv_op):
atol=ATOL,
rtol=RTOL,
)


def test_diag_blockdiag_rewrite():
n_matrices = 10
matrix_size = (5, 5)
sub_matrices = pt.tensor("sub_matrices", shape=(n_matrices, *matrix_size))
bd_output = pt.linalg.block_diag(*[sub_matrices[i] for i in range(n_matrices)])
diag_output = pt.diag(bd_output)
f_rewritten = function([sub_matrices], diag_output, mode="FAST_RUN")

# Rewrite Test
nodes = f_rewritten.maker.fgraph.apply_nodes
assert not any(isinstance(node.op, BlockDiagonal) for node in nodes)

# Value Test
sub_matrices_test = np.random.rand(n_matrices, *matrix_size).astype(config.floatX)
bd_output_test = scipy.linalg.block_diag(
*[sub_matrices_test[i] for i in range(n_matrices)]
)
diag_output_test = np.diag(bd_output_test)
rewritten_val = f_rewritten(sub_matrices_test)
assert_allclose(
diag_output_test,
rewritten_val,
atol=1e-3 if config.floatX == "float32" else 1e-8,
rtol=1e-3 if config.floatX == "float32" else 1e-8,
)


def test_det_blockdiag_rewrite():
n_matrices = 100
matrix_size = (5, 5)
sub_matrices = pt.tensor("sub_matrices", shape=(n_matrices, *matrix_size))
bd_output = pt.linalg.block_diag(*[sub_matrices[i] for i in range(n_matrices)])
det_output = pt.linalg.det(bd_output)
f_rewritten = function([sub_matrices], det_output, mode="FAST_RUN")

# Rewrite Test
nodes = f_rewritten.maker.fgraph.apply_nodes
assert not any(isinstance(node.op, BlockDiagonal) for node in nodes)

# Value Test
sub_matrices_test = np.random.rand(n_matrices, *matrix_size).astype(config.floatX)
bd_output_test = scipy.linalg.block_diag(
*[sub_matrices_test[i] for i in range(n_matrices)]
)
det_output_test = np.linalg.det(bd_output_test)
rewritten_val = f_rewritten(sub_matrices_test)
assert_allclose(
det_output_test,
rewritten_val,
atol=1e-3 if config.floatX == "float32" else 1e-8,
rtol=1e-3 if config.floatX == "float32" else 1e-8,
)


def test_slogdet_blockdiag_rewrite():
n_matrices = 100
matrix_size = (5, 5)
sub_matrices = pt.tensor("sub_matrices", shape=(n_matrices, *matrix_size))
bd_output = pt.linalg.block_diag(*[sub_matrices[i] for i in range(n_matrices)])
sign_output, logdet_output = pt.linalg.slogdet(bd_output)
f_rewritten = function(
[sub_matrices], [sign_output, logdet_output], mode="FAST_RUN"
)

# Rewrite Test
nodes = f_rewritten.maker.fgraph.apply_nodes
assert not any(isinstance(node.op, BlockDiagonal) for node in nodes)

# Value Test
sub_matrices_test = np.random.rand(n_matrices, *matrix_size).astype(config.floatX)
bd_output_test = scipy.linalg.block_diag(
*[sub_matrices_test[i] for i in range(n_matrices)]
)
sign_output_test, logdet_output_test = np.linalg.slogdet(bd_output_test)
rewritten_sign_val, rewritten_logdet_val = f_rewritten(sub_matrices_test)
assert_allclose(
sign_output_test,
rewritten_sign_val,
atol=1e-3 if config.floatX == "float32" else 1e-8,
rtol=1e-3 if config.floatX == "float32" else 1e-8,
)
assert_allclose(
logdet_output_test,
rewritten_logdet_val,
atol=1e-3 if config.floatX == "float32" else 1e-8,
rtol=1e-3 if config.floatX == "float32" else 1e-8,
)
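
To see one of these rewrites fire end to end, here is a minimal sketch (illustration only, not part of this PR; it uses only the public API already exercised by the tests above). It compiles det(block_diag(a, b)) with FAST_RUN, which applies the canonicalize/stabilize rewrites registered in this PR, and checks the result against the product of the individual determinants:

import numpy as np
import pytensor
import pytensor.tensor as pt

a = pt.matrix("a")
b = pt.matrix("b")
det_bd = pt.linalg.det(pt.linalg.block_diag(a, b))

# FAST_RUN applies the registered rewrites, so det(block_diag(a, b))
# is replaced by det(a) * det(b) in the compiled graph.
f = pytensor.function([a, b], det_bd, mode="FAST_RUN")
pytensor.dprint(f)  # no BlockDiagonal op should appear in the printed graph

a_val = np.random.rand(3, 3).astype(pytensor.config.floatX)
b_val = np.random.rand(2, 2).astype(pytensor.config.floatX)
np.testing.assert_allclose(
    f(a_val, b_val),
    np.linalg.det(a_val) * np.linalg.det(b_val),
    rtol=1e-3 if pytensor.config.floatX == "float32" else 1e-8,
)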